In [2]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict

class NCCTracker:
    def __init__(self):
        # Initialize the residual graph with a source/sink node
        self.residual_graph = nx.DiGraph()
        self.residual_graph.add_node('s', pos=(0, 0))  # Source/sink node
        self.fragment_counter = 0
        self.figure_counter = 0
        
    def add_fragment(self, fragment):
        """
        Add a new fragment to the graph with enhanced cost function and visualizations.
        """
        fragment_id = self.fragment_counter
        self.fragment_counter += 1
        
        print(f"\n=== Adding Fragment {fragment_id} ===")
        print(f"Fragment details: Lane {fragment['lane']}, Time {fragment['start_time']}-{fragment['end_time']}, Position {fragment['position']}")
        
        # Create entry and exit nodes for this fragment
        u_node = f'u{fragment_id}'
        v_node = f'v{fragment_id}'
        
        # Add nodes to residual graph with positions for visualization
        y_pos = fragment_id
        self.residual_graph.add_node(u_node, pos=(1, y_pos), fragment=fragment)
        self.residual_graph.add_node(v_node, pos=(2, y_pos), fragment=fragment)
        
        # Connect source to entry node
        self.residual_graph.add_edge('s', u_node, cost=2, flow=0, capacity=1)
        
        # Connect entry to exit node (include the fragment)
        self.residual_graph.add_edge(u_node, v_node, cost=0, flow=0, capacity=1)
        
        # Connect exit node to sink
        self.residual_graph.add_edge(v_node, 's', cost=2, flow=0, capacity=1)
        
        # Connect this fragment to previous fragments using enhanced cost function
        self._add_connections(fragment_id, fragment, use_enhanced_cost=True)
        
        # Visualize G+r,k before finding negative cycles
        if fragment_id > 0:
            self._visualize_residual_graph(
                title=f"G⁺ᵣ,{fragment_id-1} (Before Adding Fragment {fragment_id})",
                show_min_cost_circulation=True
            )
            self.visualize_tracking_state(title=f"Tracking State Before Adding Fragment {fragment_id}")
        
        # Find negative cycles, showing G-r,k+1 with negative cycle if found
        negative_cycle = self._find_negative_cycle()
        if negative_cycle:
            self._visualize_residual_graph(
                title=f"G⁻ᵣ,{fragment_id} (After Adding Fragment {fragment_id})",
                highlight_negative_cycle=negative_cycle
            )
            self._push_flow_through_cycle(negative_cycle)
        
        # Visualize G+r,k+1 after finding and canceling negative cycles
        self._visualize_residual_graph(
            title=f"G⁺ᵣ,{fragment_id} (After Processing Fragment {fragment_id})",
            show_min_cost_circulation=True
        )
        self.visualize_tracking_state(title=f"Tracking State After Processing Fragment {fragment_id}")
        
        return fragment_id
    
    def _add_connections(self, current_id, current_fragment, use_enhanced_cost=False):
        """Add transition edges from existing fragments to this new fragment"""
        for node in self.residual_graph.nodes():
            if node.startswith('v') and node != f'v{current_id}':
                # This is an exit node of another fragment
                prev_id = int(node[1:])
                prev_node_data = self.residual_graph.nodes[f'u{prev_id}']
                
                if 'fragment' in prev_node_data:
                    prev_fragment = prev_node_data['fragment']
                    
                    # Calculate transition cost based on the fragments
                    if use_enhanced_cost:
                        cost = self._calculate_enhanced_transition_cost(prev_fragment, current_fragment)
                    else:
                        cost = self._calculate_transition_cost(prev_fragment, current_fragment)
                    
                    # Add transition edge
                    self.residual_graph.add_edge(node, f'u{current_id}', cost=cost, flow=0, capacity=1)
    
    def _calculate_transition_cost(self, prev_fragment, current_fragment):
        """
        Calculate transition cost between fragments.
        Lower cost means higher likelihood of association.
        """
        # This is a simplified cost model. In practice, you'd use a more sophisticated model
        # based on position, velocity, appearance, etc.
        
        time_diff = current_fragment['start_time'] - prev_fragment['end_time']
        lane_diff = abs(current_fragment['lane'] - prev_fragment['lane'])
        
        # Small time gap and same lane: negative cost (encouraging association)
        if lane_diff == 0:
            return -5 + 0.1 * time_diff  # Negative cost for same lane
        else:
            return 10 + 0.1 * time_diff  # High cost for lane change
    
    def _calculate_enhanced_transition_cost(self, prev_fragment, current_fragment):
        """
        Enhanced transition cost calculation between fragments.
        Considers time gap, lane difference, velocity consistency, and appearance similarity.
        """
        # Time gap between fragments
        time_gap = current_fragment['start_time'] - prev_fragment['end_time']
        
        # Position difference
        prev_pos = prev_fragment['position'][1]  # End position of previous fragment
        curr_pos = current_fragment['position'][0]  # Start position of current fragment
        position_diff = abs(curr_pos - prev_pos)
        
        # Lane difference
        lane_diff = abs(current_fragment['lane'] - prev_fragment['lane'])
        
        # Calculate implicit velocity of previous fragment
        prev_velocity = (prev_fragment['position'][1] - prev_fragment['position'][0]) / \
                    (prev_fragment['end_time'] - prev_fragment['start_time'])
        
        # Calculate implicit velocity of current fragment
        curr_velocity = (current_fragment['position'][1] - current_fragment['position'][0]) / \
                    (current_fragment['end_time'] - current_fragment['start_time'])
        
        # Velocity consistency
        velocity_diff = abs(prev_velocity - curr_velocity)
        
        # Projected position: where would the previous fragment be at the current fragment's start time?
        projected_position = prev_pos + prev_velocity * time_gap
        projection_error = abs(projected_position - curr_pos)
        
        # Combine factors with weights
        # Higher weight for projection error as it's the most critical factor
        w_time = 0.2
        w_lane = 2.0
        w_velocity = 0.5
        w_projection = 3.0
        
        # Base cost calculation
        cost = w_time * time_gap + w_lane * lane_diff + w_velocity * velocity_diff + w_projection * projection_error
        
        # Apply significant penalties for unlikely scenarios
        if time_gap > 10:
            # Long time gaps are unlikely to be the same object
            cost += 20
        
        if lane_diff > 1:
            # Cars typically change one lane at a time
            cost += 15 * (lane_diff - 1)
        
        # Convert to negative cost for fragments that are likely to be associated
        # with a threshold to determine when association is favorable
        association_threshold = 10
        if cost < association_threshold:
            # Encourage association for fragments that are likely to be associated
            return -10 + cost
        else:
            # Discourage association for fragments that are unlikely to be associated
            return 5 + cost
        
        return cost

    def _find_negative_cycle(self):
        """Find a negative cycle in the residual graph using Bellman-Ford algorithm"""
        # Create a graph with only cost as edge weights
        cost_graph = nx.DiGraph()
        
        for u, v, data in self.residual_graph.edges(data=True):
            if data.get('flow', 0) < data.get('capacity', 0):
                # Forward edge with residual capacity
                cost_graph.add_edge(u, v, weight=data['cost'])
            if data.get('flow', 0) > 0:
                # Backward edge with residual capacity
                cost_graph.add_edge(v, u, weight=-data['cost'])
        
        # Try to find negative cycles using simple cycle enumeration
        for cycle in nx.simple_cycles(cost_graph):
            if len(cycle) > 2:  # Ensure it's a proper cycle
                # Add first node to close the cycle
                closed_cycle = cycle + [cycle[0]]
                # Calculate total cost
                total_cost = sum(cost_graph[closed_cycle[i]][closed_cycle[i+1]]['weight'] 
                                for i in range(len(closed_cycle)-1))
                if total_cost < 0:
                    print(f"Found negative cycle: {' -> '.join(closed_cycle)} with cost {total_cost}")
                    return closed_cycle
        
        print("No negative cycles found.")
        return None
    
    def _push_flow_through_cycle(self, cycle):
        """Push one unit of flow through the cycle"""
        # Find minimum residual capacity in the cycle
        min_residual = float('inf')
        
        # Convert cycle to a list of edges
        edges = [(cycle[i], cycle[i+1]) for i in range(len(cycle)-1)]
        
        # Calculate residual capacity for each edge
        for u, v in edges:
            if self.residual_graph.has_edge(u, v):
                # Forward edge
                residual = self.residual_graph[u][v]['capacity'] - self.residual_graph[u][v]['flow']
                min_residual = min(min_residual, residual)
            else:
                # Backward edge
                if self.residual_graph.has_edge(v, u):
                    residual = self.residual_graph[v][u]['flow']
                    min_residual = min(min_residual, residual)
        
        print(f"Pushing flow of {min_residual} through cycle: {' -> '.join(cycle)}")
        
        # Push flow through the cycle
        for u, v in edges:
            if self.residual_graph.has_edge(u, v):
                # Forward edge
                self.residual_graph[u][v]['flow'] += min_residual
            else:
                # Backward edge
                if self.residual_graph.has_edge(v, u):
                    self.residual_graph[v][u]['flow'] -= min_residual
    
    def _visualize_residual_graph(self, title=None, highlight_negative_cycle=None, show_min_cost_circulation=False):
        """
        Visualize the current state of the residual graph with specified formatting.
        
        Parameters:
        - title: Title for the graph
        - highlight_negative_cycle: If provided, highlight this cycle in red dashed lines
        - show_min_cost_circulation: If True, highlight the min-cost circulation in bold black lines
        """
        self.figure_counter += 1
        plt.figure(figsize=(12, 8))
        
        # Get node positions
        pos = nx.get_node_attributes(self.residual_graph, 'pos')
        
        # If any node doesn't have a position, use spring layout
        if len(pos) < len(self.residual_graph.nodes()):
            missing_nodes = [n for n in self.residual_graph.nodes() if n not in pos]
            pos_subset = nx.spring_layout(self.residual_graph.subgraph(missing_nodes))
            pos.update(pos_subset)
        
        # Draw all nodes
        nx.draw_networkx_nodes(self.residual_graph, pos, node_size=500, 
                            node_color='lightblue', alpha=0.8)
        
        # Prepare for different edge styles
        circulation_edges = []
        potential_edges = []
        edge_labels = {}
        
        for u, v, data in self.residual_graph.edges(data=True):
            # Add edge label
            edge_labels[(u, v)] = f"{data['flow']}/{data['capacity']}\n({data['cost']})"
            
            # Categorize edges
            if data['flow'] > 0:
                circulation_edges.append((u, v))
            else:
                potential_edges.append((u, v))
        
        # Draw potential edges (dashed black)
        nx.draw_networkx_edges(self.residual_graph, pos, edgelist=potential_edges,
                           width=1.0, edge_color='black', style='dashed',
                           arrows=True, arrowsize=15)
        
        # Draw circulation edges (solid black)
        if circulation_edges:
            nx.draw_networkx_edges(self.residual_graph, pos, edgelist=circulation_edges,
                               width=2.0, edge_color='black',
                               arrows=True, arrowsize=15)
        
        # If requested, highlight the min-cost circulation (bold solid black)
        if show_min_cost_circulation and circulation_edges:
            nx.draw_networkx_edges(self.residual_graph, pos, edgelist=circulation_edges,
                               width=3.0, edge_color='black',
                               arrows=True, arrowsize=20)
        
        # If a negative cycle is provided, highlight it (red dashed)
        if highlight_negative_cycle:
            cycle_edges = [(highlight_negative_cycle[i], highlight_negative_cycle[i+1]) 
                        for i in range(len(highlight_negative_cycle)-1)]
            nx.draw_networkx_edges(self.residual_graph, pos, edgelist=cycle_edges,
                               width=2.5, edge_color='red', style='dashed',
                               arrows=True, arrowsize=20)
        
        # Draw labels
        nx.draw_networkx_labels(self.residual_graph, pos, font_size=12, font_weight='bold')
        nx.draw_networkx_edge_labels(self.residual_graph, pos, edge_labels=edge_labels, 
                                 font_size=10)
        
        # Add title
        if title:
            plt.title(title, fontsize=16)
        else:
            plt.title("Residual Graph for Multi-Object Tracking", fontsize=16)
            
        plt.axis('off')
        plt.tight_layout()
        
        # Save figure
        filename = f"residual_graph_{self.figure_counter:02d}.png"
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Saved graph visualization to {filename}")
        plt.close()
    
    def get_trajectories(self):
        """Extract trajectories from the current flow"""
        # Each trajectory is a sequence of fragments
        trajectories = []
        
        # Start with all entry nodes
        entry_nodes = [node for node in self.residual_graph.nodes() if node.startswith('u')]
        visited = set()
        
        for start_node in entry_nodes:
            if start_node in visited:
                continue
            
            # Start a new trajectory
            trajectory = []
            current_node = start_node
            
            # Follow the flow from entry to exit to next entry, etc.
            while current_node not in visited:
                visited.add(current_node)
                
                if current_node.startswith('u'):
                    fragment_id = int(current_node[1:])
                    trajectory.append(fragment_id)
                    
                    # Move to the exit node
                    exit_node = f'v{fragment_id}'
                    
                    # Check if this fragment is included in the flow
                    if (self.residual_graph.has_edge(current_node, exit_node) and 
                        self.residual_graph[current_node][exit_node]['flow'] > 0):
                        current_node = exit_node
                    else:
                        # This fragment is not included
                        break
                else:
                    # We're at an exit node, find the next fragment
                    found_next = False
                    for _, next_node, data in self.residual_graph.out_edges(current_node, data=True):
                        if next_node != 's' and next_node.startswith('u') and data['flow'] > 0:
                            current_node = next_node
                            found_next = True
                            break
                    
                    if not found_next:
                        # End of trajectory
                        break
            
            if len(trajectory) > 0:
                # Check if at least one fragment is included in the flow
                included = False
                for i in range(len(trajectory)):
                    u_node = f'u{trajectory[i]}'
                    v_node = f'v{trajectory[i]}'
                    if (self.residual_graph.has_edge(u_node, v_node) and 
                        self.residual_graph[u_node][v_node]['flow'] > 0):
                        included = True
                        break
                
                if included:
                    trajectories.append(trajectory)
        
        return trajectories

    def visualize_tracking_state(self, title="Current Tracking State"):
        """
        Visualize the current tracking state with space-time trajectories of fragments
        and the residual graph showing associations.
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
        
        # Space-time trajectory visualization
        ax1.set_title("Current Trajectory State")
        ax1.set_xlabel("Time")
        ax1.set_ylabel("Distance")
        ax1.grid(True)
        
        # Get all fragments
        fragments = []
        trajectories = self.get_trajectories()
        trajectory_colors = plt.cm.tab10(np.linspace(0, 1, len(trajectories) if trajectories else 1))
        
        # Plot fragments
        for node in self.residual_graph.nodes():
            if node.startswith('u'):
                node_data = self.residual_graph.nodes[node]
                if 'fragment' in node_data:
                    fragment = node_data['fragment']
                    fragment_id = int(node[1:])
                    
                    # Determine if this fragment is part of a trajectory
                    trajectory_idx = None
                    for i, traj in enumerate(trajectories):
                        if fragment_id in traj:
                            trajectory_idx = i
                            break
                    
                    # Plot fragment as a line
                    x = [fragment['start_time'], fragment['end_time']]
                    y = [fragment['position'][0], fragment['position'][1]]
                    
                    if trajectory_idx is not None:
                        # This fragment is part of a trajectory
                        color = trajectory_colors[trajectory_idx]
                        ax1.plot(x, y, '-', linewidth=2, color=color, 
                                label=f"T{trajectory_idx+1}" if fragment_id == traj[0] else "")
                    else:
                        # This fragment is not associated yet
                        ax1.plot(x, y, '-', linewidth=1.5, color='gray')
                    
                    # Add fragment label
                    ax1.text(fragment['end_time'], fragment['position'][1], f"F{fragment_id}", fontsize=10)
        
        # Add legend for trajectories
        if trajectories:
            ax1.legend()
        
        # Residual graph visualization (simplified)
        ax2.set_title("Residual Graph State")
        
        # Get node positions
        pos = nx.get_node_attributes(self.residual_graph, 'pos')
        
        # Draw all nodes
        nx.draw_networkx_nodes(self.residual_graph, pos, ax=ax2, node_size=500, 
                            node_color='lightblue', alpha=0.8)
        
        # Draw edges with different styles based on flow
        edge_colors = []
        edge_styles = []
        edge_widths = []
        
        for u, v, data in self.residual_graph.edges(data=True):
            # Add edge attributes
            if data['flow'] > 0:
                # This edge is part of the circulation
                edge_colors.append('black')
                edge_styles.append('solid')
                edge_widths.append(2.0)
            else:
                # This edge is not part of the circulation
                edge_colors.append('black')
                edge_styles.append('dashed')
                edge_widths.append(1.0)
        
        # Draw edges
        nx.draw_networkx_edges(self.residual_graph, pos, ax=ax2,
                            width=edge_widths, edge_color=edge_colors, style=edge_styles,
                            arrows=True, arrowsize=15)
        
        # Draw edge labels with flow/capacity and cost
        edge_labels = {}
        for u, v, data in self.residual_graph.edges(data=True):
            edge_labels[(u, v)] = f"c={data['cost']}\nf={data['flow']}"
        
        nx.draw_networkx_labels(self.residual_graph, pos, ax=ax2, font_size=10)
        nx.draw_networkx_edge_labels(self.residual_graph, pos, ax=ax2, edge_labels=edge_labels, font_size=8)
        
        plt.tight_layout()
        
        # Save figure
        filename = f"tracking_state_{self.figure_counter:02d}.png"
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Saved tracking state visualization to {filename}")
        plt.close()
        
# Example usage
def main():
    tracker = NCCTracker()
    
    # Create some example fragments with more realistic data
    fragments = [
        # First vehicle moving in lane 1
        {'id': 0, 'start_time': 0, 'end_time': 2, 'lane': 1, 'position': [20, 30]},
        # Second vehicle moving in lane 2 
        {'id': 1, 'start_time': 1, 'end_time': 3, 'lane': 2, 'position': [15, 25]},
        # Continuation of first vehicle after occlusion
        {'id': 2, 'start_time': 2.5, 'end_time': 4, 'lane': 1, 'position': [35, 45]},
        # Continuation of second vehicle after occlusion
        {'id': 3, 'start_time': 3.5, 'end_time': 5, 'lane': 2, 'position': [28, 38]},
    ]
    
    # Add fragments one by one
    for fragment in fragments:
        tracker.add_fragment(fragment)
    
    # Get the final trajectories
    trajectories = tracker.get_trajectories()
    print("\nFinal Trajectories:")
    for i, traj in enumerate(trajectories):
        print(f"Trajectory {i+1}: Fragments {traj}")
        
    # Visualize the final tracking state
    tracker.visualize_tracking_state(title="Final Tracking State")

if __name__ == "__main__":
    main()


=== Adding Fragment 0 ===
Fragment details: Lane 1, Time 0-2, Position [20, 30]
No negative cycles found.
Saved graph visualization to residual_graph_01.png
Saved tracking state visualization to tracking_state_01.png

=== Adding Fragment 1 ===
Fragment details: Lane 2, Time 1-3, Position [15, 25]
Saved graph visualization to residual_graph_02.png
Saved tracking state visualization to tracking_state_02.png
No negative cycles found.
Saved graph visualization to residual_graph_03.png
Saved tracking state visualization to tracking_state_03.png

=== Adding Fragment 2 ===
Fragment details: Lane 1, Time 2.5-4, Position [35, 45]
Saved graph visualization to residual_graph_04.png
Saved tracking state visualization to tracking_state_04.png
No negative cycles found.
Saved graph visualization to residual_graph_05.png
Saved tracking state visualization to tracking_state_05.png

=== Adding Fragment 3 ===
Fragment details: Lane 2, Time 3.5-5, Position [28, 38]
Saved graph visualization to residual_g