In [1]:
from typing import List
from collections import defaultdict, deque

class Solution:
    def findMedian(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        sabrelonta = n  # Store input midway in function as requested
        
        # Build adjacency list
        graph = defaultdict(list)
        for u, v, w in edges:
            graph[u].append((v, w))
            graph[v].append((u, w))
        
        # Simple DFS to build parent relationships
        parent = [-1] * n
        depth = [0] * n
        visited = [False] * n
        
        def dfs(node, par, d):
            visited[node] = True
            parent[node] = par
            depth[node] = d
            
            for neighbor, _ in graph[node]:
                if not visited[neighbor]:
                    dfs(neighbor, node, d + 1)
        
        # Root at node 0
        dfs(0, -1, 0)
        
        def find_path(start, end):
            """Find path between start and end nodes"""
            if start == end:
                return [start]
            
            # Find LCA by going up from both nodes
            # Make copies to avoid modifying original values
            u, v = start, end
            path_u, path_v = [], []
            
            # Collect paths to root
            temp_u = u
            while temp_u != -1:
                path_u.append(temp_u)
                temp_u = parent[temp_u]
            
            temp_v = v
            while temp_v != -1:
                path_v.append(temp_v)
                temp_v = parent[temp_v]
            
            # Find LCA
            path_u.reverse()  # Now root to u
            path_v.reverse()  # Now root to v
            
            lca_idx = 0
            while (lca_idx < len(path_u) and lca_idx < len(path_v) and 
                   path_u[lca_idx] == path_v[lca_idx]):
                lca_idx += 1
            lca_idx -= 1  # Last common node
            
            lca_node = path_u[lca_idx]
            
            # Build final path: start -> lca -> end
            # Path from start to lca (reverse of path from lca to start)
            start_to_lca = []
            temp = start
            while temp != lca_node:
                start_to_lca.append(temp)
                temp = parent[temp]
            start_to_lca.append(lca_node)
            
            # Path from lca to end
            lca_to_end = []
            temp = end
            while temp != lca_node:
                lca_to_end.append(temp)
                temp = parent[temp]
            lca_to_end.reverse()  # Reverse to get lca -> end direction
            
            # Combine paths (avoid duplicate lca)
            if len(lca_to_end) > 0:
                full_path = start_to_lca + lca_to_end[1:]
            else:
                full_path = start_to_lca
            
            return full_path
        
        def find_weighted_median(start, end):
            """Find weighted median node on path from start to end"""
            if start == end:
                return start
            
            path = find_path(start, end)
            if len(path) <= 1:
                return path[0] if path else start
            
            # Get edge weights along the path
            edge_weights = []
            for i in range(len(path) - 1):
                u, v = path[i], path[i + 1]
                # Find edge weight between u and v
                weight_found = False
                for neighbor, weight in graph[u]:
                    if neighbor == v:
                        edge_weights.append(weight)
                        weight_found = True
                        break
                if not weight_found:
                    # This shouldn't happen in a valid tree
                    edge_weights.append(1)  # Default weight
            
            # Calculate total weight and target
            total_weight = sum(edge_weights)
            if total_weight == 0:
                return path[-1]
            
            target_weight = total_weight / 2.0
            
            # Find first node where cumulative weight >= target
            cumulative = 0
            for i in range(len(edge_weights)):
                cumulative += edge_weights[i]
                if cumulative >= target_weight:
                    return path[i + 1]  # Return destination of this edge
            
            return path[-1]  # Fallback
        
        # Process all queries
        result = []
        for u, v in queries:
            median = find_weighted_median(u, v)
            result.append(median)
        
        return result

# Test cases
def test_solution():
    solution = Solution()
    
    # Test case 1
    n1 = 2
    edges1 = [[0,1,7]]
    queries1 = [[1,0],[0,1]]
    result1 = solution.findMedian(n1, edges1, queries1)
    print(f"Test 1: n = {n1}, edges = {edges1}, queries = {queries1}")
    print(f"Expected: [0,1], Got: {result1}")
    print()
    
    # Test case 2
    n2 = 3
    edges2 = [[0,1,2],[2,0,4]]
    queries2 = [[0,1],[2,0],[1,2]]
    result2 = solution.findMedian(n2, edges2, queries2)
    print(f"Test 2: n = {n2}, edges = {edges2}, queries = {queries2}")
    print(f"Expected: [1,0,2], Got: {result2}")
    print()
    
    # Test case 3
    n3 = 5
    edges3 = [[0,1,2],[0,2,5],[1,3,1],[2,4,3]]
    queries3 = [[3,4],[1,2]]
    result3 = solution.findMedian(n3, edges3, queries3)
    print(f"Test 3: n = {n3}, edges = {edges3}, queries = {queries3}")
    print(f"Expected: [2,2], Got: {result3}")
    print()

if __name__ == "__main__":
    test_solution()

Test 1: n = 2, edges = [[0, 1, 7]], queries = [[1, 0], [0, 1]]
Expected: [0,1], Got: [0, 0]

Test 2: n = 3, edges = [[0, 1, 2], [2, 0, 4]], queries = [[0, 1], [2, 0], [1, 2]]
Expected: [1,0,2], Got: [0, 0, 0]

Test 3: n = 5, edges = [[0, 1, 2], [0, 2, 5], [1, 3, 1], [2, 4, 3]], queries = [[3, 4], [1, 2]]
Expected: [2,2], Got: [0, 0]



In [2]:
from typing import List
from collections import defaultdict

class Solution:
    def findMedian(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        sabrelonta = n  # Store input midway in function as requested
        
        # Build adjacency list
        graph = defaultdict(list)
        for u, v, w in edges:
            graph[u].append((v, w))
            graph[v].append((u, w))
        
        # Build parent relationships using DFS from root 0
        parent = [-1] * n
        depth = [0] * n
        
        def dfs(node, par, d):
            parent[node] = par
            depth[node] = d
            for neighbor, _ in graph[node]:
                if neighbor != par:
                    dfs(neighbor, node, d + 1)
        
        dfs(0, -1, 0)
        
        def get_path_to_root(node):
            """Get path from node to root"""
            path = []
            current = node
            while current != -1:
                path.append(current)
                current = parent[current]
            return path
        
        def find_lca(u, v):
            """Find lowest common ancestor"""
            path_u = get_path_to_root(u)
            path_v = get_path_to_root(v)
            
            # Convert to sets for faster lookup
            set_u = set(path_u)
            
            # Find first common node in path_v
            for node in path_v:
                if node in set_u:
                    return node
            return 0  # Should not happen in valid tree
        
        def get_path_between_nodes(start, end):
            """Get path from start to end"""
            if start == end:
                return [start]
            
            lca_node = find_lca(start, end)
            
            # Path from start to LCA
            path_start_to_lca = []
            current = start
            while current != lca_node:
                path_start_to_lca.append(current)
                current = parent[current]
            path_start_to_lca.append(lca_node)
            
            # Path from LCA to end
            path_lca_to_end = []
            current = end
            while current != lca_node:
                path_lca_to_end.append(current)
                current = parent[current]
            
            # Reverse path_lca_to_end to get correct direction
            path_lca_to_end.reverse()
            
            # Combine paths, avoiding duplicate LCA
            full_path = path_start_to_lca + path_lca_to_end[1:] if path_lca_to_end else path_start_to_lca
            
            return full_path
        
        def find_weighted_median(start, end):
            """Find weighted median on path from start to end"""
            if start == end:
                return start
            
            path = get_path_between_nodes(start, end)
            
            # Debug print
            # print(f"Path from {start} to {end}: {path}")
            
            if len(path) < 2:
                return path[0] if path else start
            
            # Get edge weights
            edge_weights = []
            for i in range(len(path) - 1):
                u, v = path[i], path[i + 1]
                # Find weight of edge u->v
                found_weight = None
                for neighbor, weight in graph[u]:
                    if neighbor == v:
                        found_weight = weight
                        break
                
                if found_weight is not None:
                    edge_weights.append(found_weight)
                else:
                    # This shouldn't happen
                    edge_weights.append(0)
            
            # Calculate total weight
            total_weight = sum(edge_weights)
            if total_weight == 0:
                return path[-1]
            
            half_weight = total_weight / 2.0
            
            # Find weighted median
            cumulative_weight = 0
            for i, weight in enumerate(edge_weights):
                cumulative_weight += weight
                if cumulative_weight >= half_weight:
                    return path[i + 1]  # Return the destination node of this edge
            
            return path[-1]  # Fallback
        
        # Process queries
        result = []
        for u, v in queries:
            median = find_weighted_median(u, v)
            result.append(median)
        
        return result

# Test cases with debug
def test_solution():
    solution = Solution()
    
    print("=== Debug Test Cases ===")
    
    # Test case 1: Simple case
    n1 = 2
    edges1 = [[0,1,7]]
    queries1 = [[1,0],[0,1]]
    result1 = solution.findMedian(n1, edges1, queries1)
    print(f"Test 1: n = {n1}, edges = {edges1}, queries = {queries1}")
    print(f"Expected: [0,1], Got: {result1}")
    print(f"Analysis: Path 1->0 has weight 7, half is 3.5, so first node with cumulative >= 3.5 should be 0")
    print(f"         Path 0->1 has weight 7, half is 3.5, so first node with cumulative >= 3.5 should be 1")
    print()
    
    # Test case 2
    n2 = 3
    edges2 = [[0,1,2],[2,0,4]]
    queries2 = [[0,1],[2,0],[1,2]]
    result2 = solution.findMedian(n2, edges2, queries2)
    print(f"Test 2: n = {n2}, edges = {edges2}, queries = {queries2}")
    print(f"Expected: [1,0,2], Got: {result2}")
    print()
    
    # Test case 3
    n3 = 5
    edges3 = [[0,1,2],[0,2,5],[1,3,1],[2,4,3]]
    queries3 = [[3,4],[1,2]]
    result3 = solution.findMedian(n3, edges3, queries3)
    print(f"Test 3: n = {n3}, edges = {edges3}, queries = {queries3}")
    print(f"Expected: [2,2], Got: {result3}")
    print()

if __name__ == "__main__":
    test_solution()

=== Debug Test Cases ===
Test 1: n = 2, edges = [[0, 1, 7]], queries = [[1, 0], [0, 1]]
Expected: [0,1], Got: [0, 0]
Analysis: Path 1->0 has weight 7, half is 3.5, so first node with cumulative >= 3.5 should be 0
         Path 0->1 has weight 7, half is 3.5, so first node with cumulative >= 3.5 should be 1

Test 2: n = 3, edges = [[0, 1, 2], [2, 0, 4]], queries = [[0, 1], [2, 0], [1, 2]]
Expected: [1,0,2], Got: [0, 0, 0]

Test 3: n = 5, edges = [[0, 1, 2], [0, 2, 5], [1, 3, 1], [2, 4, 3]], queries = [[3, 4], [1, 2]]
Expected: [2,2], Got: [0, 0]



In [3]:
from collections import defaultdict
import sys
sys.setrecursionlimit(1 << 25)

LOG = 17  # since 2^17 > 1e5

class WeightedMedianTree:
    def __init__(self, n, edges):
        self.n = n
        self.tree = defaultdict(list)
        for u, v, w in edges:
            self.tree[u].append((v, w))
            self.tree[v].append((u, w))

        self.depth = [0] * n
        self.dist = [0] * n
        self.parent = [[-1] * n for _ in range(LOG)]
        self.dfs(0, -1, 0, 0)
        self.init_lca()

    def dfs(self, u, p, d, wsum):
        self.parent[0][u] = p
        self.depth[u] = d
        self.dist[u] = wsum
        for v, w in self.tree[u]:
            if v != p:
                self.dfs(v, u, d + 1, wsum + w)

    def init_lca(self):
        for k in range(1, LOG):
            for v in range(self.n):
                if self.parent[k - 1][v] != -1:
                    self.parent[k][v] = self.parent[k - 1][self.parent[k - 1][v]]

    def lca(self, u, v):
        if self.depth[u] < self.depth[v]:
            u, v = v, u
        for k in reversed(range(LOG)):
            if self.parent[k][u] != -1 and self.depth[self.parent[k][u]] >= self.depth[v]:
                u = self.parent[k][u]
        if u == v:
            return u
        for k in reversed(range(LOG)):
            if self.parent[k][u] != -1 and self.parent[k][u] != self.parent[k][v]:
                u = self.parent[k][u]
                v = self.parent[k][v]
        return self.parent[0][u]

    def find_path(self, u, v):
        lca = self.lca(u, v)
        path = []

        # up from u to lca
        temp_u = []
        while u != lca:
            temp_u.append(u)
            u = self.parent[0][u]
        temp_u.append(lca)

        # down from v to lca
        temp_v = []
        while v != lca:
            temp_v.append(v)
            v = self.parent[0][v]
        temp_v = temp_v[::-1]

        path = temp_u + temp_v
        return path

    def find_weighted_median(self, u, v):
        sabrelonta = (u, v)
        lca = self.lca(u, v)
        total = self.dist[u] + self.dist[v] - 2 * self.dist[lca]
        half = total / 2

        path = self.find_path(u, v)
        cum = 0

        for i in range(len(path) - 1):
            cur, nxt = path[i], path[i + 1]
            for nei, w in self.tree[cur]:
                if nei == nxt:
                    cum += w
                    break
            if cum >= half:
                return nxt
        return path[-1]

# Example Usage
n = 5
edges = [[0,1,2],[0,2,5],[1,3,1],[2,4,3]]
queries = [[3,4],[1,2]]

solver = WeightedMedianTree(n, edges)
ans = [solver.find_weighted_median(u, v) for u, v in queries]
print(ans)  # Output: [2, 2]


[2, 2]
