In [None]:
# Overview of Disjoint Set

![image.png](attachment:image.png)

![image.png](attachment:image.png)

In [1]:
# UnionFind class
class UnionFind:
    def __init__(self, size):
        self.root = [i for i in range(size)]

    def find(self, x):
        return self.root[x]

    def union(self, x, y):
        rootX = self.find(x)
        rootY = self.find(y)
        if rootX != rootY:
            for i in range(len(self.root)):
                if self.root[i] == rootY:
                    self.root[i] = rootX

    def connected(self, x, y):
        return self.find(x) == self.find(y)


# Test Case
uf = UnionFind(10)
# 1-2-5-6-7 3-8-9 4
uf.union(1, 2)
uf.union(2, 5)
uf.union(5, 6)
uf.union(6, 7)
uf.union(3, 8)
uf.union(8, 9)
print(uf.connected(1, 5))  # true
print(uf.connected(5, 7))  # true
print(uf.connected(4, 9))  # false
# 1-2-5-6-7 3-8-9-4
uf.union(9, 4)
print(uf.connected(4, 9))  # true

True
True
False
True


![image.png](attachment:image.png)

In [None]:
# Quick Union - Disjoint Set

![image.png](attachment:image.png)

In [2]:
# UnionFind class
class UnionFind:
    def __init__(self, size):
        self.root = [i for i in range(size)]

    def find(self, x):
        while x != self.root[x]:
            x = self.root[x]
        return x

    def union(self, x, y):
        rootX = self.find(x)
        rootY = self.find(y)
        if rootX != rootY:
            self.root[rootY] = rootX

    def connected(self, x, y):
        return self.find(x) == self.find(y)


# Test Case
uf = UnionFind(10)
# 1-2-5-6-7 3-8-9 4
uf.union(1, 2)
uf.union(2, 5)
uf.union(5, 6)
uf.union(6, 7)
uf.union(3, 8)
uf.union(8, 9)
print(uf.connected(1, 5))  # true
print(uf.connected(5, 7))  # true
print(uf.connected(4, 9))  # false
# 1-2-5-6-7 3-8-9-4
uf.union(9, 4)
print(uf.connected(4, 9))  # true

True
True
False
True


![image.png](attachment:image.png)

![image.png](attachment:image.png)

In [None]:
# Union By Rank

In [5]:
# UnionFind class
class UnionFind:
    def __init__(self, size):
        self.root = [i for i in range(size)]
        self.rank = [1] * size

    def find(self, x):
        while x != self.root[x]:
            x = self.root[x]
        return x

    def union(self, x, y):
        rootX = self.find(x)
        rootY = self.find(y)
        if rootX != rootY:
            if self.rank[rootX] > self.rank[rootY]:
                self.root[rootY] = rootX
            elif self.rank[rootX] < self.rank[rootY]:
                self.root[rootX] = rootY
            else:
                self.root[rootY] = rootX
                self.rank[rootX] += 1

    def connected(self, x, y):
        return self.find(x) == self.find(y)

# Test Case
uf = UnionFind(10)
# 1-2-5-6-7 3-8-9 4
uf.union(1, 2)
uf.union(2, 5)
uf.union(5, 6)
uf.union(6, 7)
uf.union(3, 8)
uf.union(8, 9)
print(uf.connected(1, 5))  # true
print(uf.connected(5, 7))  # true
print(uf.connected(4, 9))  # false
# 1-2-5-6-7 3-8-9-4
uf.union(9, 4)
print(uf.connected(4, 9))  # true

True
True
False
True


In [None]:
# Path Compression Optimization 

![image.png](attachment:image.png)

In [6]:
# UnionFind class
class UnionFind:
    def __init__(self, size):
        self.root = [i for i in range(size)]

    def find(self, x):
        if x == self.root[x]:
            return x
        self.root[x] = self.find(self.root[x])
        return self.root[x]

    def union(self, x, y):
        rootX = self.find(x)
        rootY = self.find(y)
        if rootX != rootY:
            self.root[rootY] = rootX

    def connected(self, x, y):
        return self.find(x) == self.find(y)


# Test Case
uf = UnionFind(10)
# 1-2-5-6-7 3-8-9 4
uf.union(1, 2)
uf.union(2, 5)
uf.union(5, 6)
uf.union(6, 7)
uf.union(3, 8)
uf.union(8, 9)
print(uf.connected(1, 5))  # true
print(uf.connected(5, 7))  # true
print(uf.connected(4, 9))  # false
# 1-2-5-6-7 3-8-9-4
uf.union(9, 4)
print(uf.connected(4, 9))  # true

True
True
False
True


![image.png](attachment:image.png)

![image.png](attachment:image.png)

In [7]:
# UnionFind class
class UnionFind:
    def __init__(self, size):
        self.root = [i for i in range(size)]
        # Use a rank array to record the height of each vertex, i.e., the "rank" of each vertex.
        # The initial "rank" of each vertex is 1, because each of them is
        # a standalone vertex with no connection to other vertices.
        self.rank = [1] * size

    # The find function here is the same as that in the disjoint set with path compression.
    def find(self, x):
        if x == self.root[x]:
            return x
        self.root[x] = self.find(self.root[x])
        return self.root[x]

    # The union function with union by rank
    def union(self, x, y):
        rootX = self.find(x)
        rootY = self.find(y)
        if rootX != rootY:
            if self.rank[rootX] > self.rank[rootY]:
                self.root[rootY] = rootX
            elif self.rank[rootX] < self.rank[rootY]:
                self.root[rootX] = rootY
            else:
                self.root[rootY] = rootX
                self.rank[rootX] += 1

    def connected(self, x, y):
        return self.find(x) == self.find(y)


# Test Case
uf = UnionFind(10)
# 1-2-5-6-7 3-8-9 4
uf.union(1, 2)
uf.union(2, 5)
uf.union(5, 6)
uf.union(6, 7)
uf.union(3, 8)
uf.union(8, 9)
print(uf.connected(1, 5))  # true
print(uf.connected(5, 7))  # true
print(uf.connected(4, 9))  # false
# 1-2-5-6-7 3-8-9-4
uf.union(9, 4)
print(uf.connected(4, 9))  # true

True
True
False
True


![image.png](attachment:image.png)

![image.png](attachment:image.png)

## Implementation of the “disjoint set”

In [None]:

class UnionFind:
    # Constructor of Union-find. The size is the length of the root array.
    def __init__(self, size):
    def find(self, x):
    def union(self, x, y):
    def connected(self, x, y):

In [None]:
# A basic implementation of the find function:
def find(self, x):
    while x != self.root[x]:
        x = self.root[x]
    return x

In [None]:
# The find function – optimized with path compression:
def find(self, x):
    if x == self.root[x]:
        return x
    self.root[x] = self.find(self.root[x])
    return self.root[x]

## union function of the “disjoint set”

In [None]:
# A basic implementation of the union function:
def union(self, x, y):
    rootX = self.find(x)
    rootY = self.find(y)
    if rootX != rootY:
        self.root[rootY] = rootX

In [None]:
# The union function – Optimized by union by rank:
def union(self, x, y):
    rootX = self.find(x)
    rootY = self.find(y)
    if rootX != rootY:
        if self.rank[rootX] > self.rank[rootY]:
            self.root[rootY] = rootX
        elif self.rank[rootX] < self.rank[rootY]:
            self.root[rootX] = rootY
        else:
            self.root[rootY] = rootX
            self.rank[rootX] += 1

In [None]:
# connected function of the “disjoint set”
def connected(self, x, y):
    return self.find(x) == self.find(y)

# Number of Provinces

![image.png](attachment:image.png)

![image.png](attachment:image.png)

In [None]:
# Accepted V1
class Solution:
    def findCircleNum(self, isConnected: List[List[int]]) -> int:
        class UnionFind:
            # Constructor of Union-find. The size is the length of the root array.
            def __init__(self, size):
                self.root = [i for i in range(size)]
                self.rank = [1] * size
                
            # The find function – optimized with path compression:
            def find(self, x):
                if x == self.root[x]:
                    return x
                self.root[x] = self.find(self.root[x])
                return self.root[x]

            # The union function – Optimized by union by rank:
            def union(self, x, y):
                rootX = self.find(x)
                rootY = self.find(y)
                if rootX != rootY:
                    if self.rank[rootX] > self.rank[rootY]:
                        self.root[rootY] = rootX
                    elif self.rank[rootX] < self.rank[rootY]:
                        self.root[rootX] = rootY
                    else:
                        self.root[rootY] = rootX
                        self.rank[rootX] += 1
            
            # connected function of the “disjoint set”
            def connected(self, x, y):
                return self.find(x) == self.find(y)
            
        # Test Case
        length = len(isConnected)
        uf = UnionFind(length)
        for i in range(length):
            for j in range(i,length):
                if isConnected[i][j]:
                    uf.union(i, j)
        n = 0
        for i in range(length):
            if uf.root[i]==i:
                n+=1
        return n
            
            
            

In [None]:
# UnionFind class
class UnionFind:
    def __init__(self, size):
        self.root = [i for i in range(size)]
        # Use a rank array to record the height of each vertex, i.e., the "rank" of each vertex.
        # The initial "rank" of each vertex is 1, because each of them is
        # a standalone vertex with no connection to other vertices.
        self.rank = [1] * size
        self.count = size

    # The find function here is the same as that in the disjoint set with path compression.
    def find(self, x):
        if x == self.root[x]:
            return x
        self.root[x] = self.find(self.root[x])
        return self.root[x]

    # The union function with union by rank
    def union(self, x, y):
        rootX = self.find(x)
        rootY = self.find(y)
        if rootX != rootY:
            if self.rank[rootX] > self.rank[rootY]:
                self.root[rootY] = rootX
            elif self.rank[rootX] < self.rank[rootY]:
                self.root[rootX] = rootY
            else:
                self.root[rootY] = rootX
                self.rank[rootX] += 1
            self.count -= 1

    def getCount(self):
        return self.count


class Solution:
    def findCircleNum(self, isConnected: List[List[int]]) -> int:
        if not isConnected or len(isConnected) == 0:
            return 0
        n = len(isConnected)
        uf = UnionFind(n)
        for row in range(n):
            for col in range(row + 1, n):
                if isConnected[row][col] == 1:
                    uf.union(row, col)
        return uf.getCount()

In [8]:
# Graph Valid Tree

![image.png](attachment:image.png)

![image.png](attachment:image.png)

In [None]:
# Accepted V1
# UnionFind class
class UnionFind:
    def __init__(self, size):
        self.root = [i for i in range(size)]
        # Use a rank array to record the height of each vertex, i.e., the "rank" of each vertex.
        # The initial "rank" of each vertex is 1, because each of them is
        # a standalone vertex with no connection to other vertices.
        self.rank = [1] * size
        self.count = size

    # The find function here is the same as that in the disjoint set with path compression.
    def find(self, x):
        if x == self.root[x]:
            return x
        self.root[x] = self.find(self.root[x])
        return self.root[x]

    # The union function with union by rank
    def union(self, x, y):
        rootX = self.find(x)
        rootY = self.find(y)
        if rootX != rootY:
            if self.rank[rootX] > self.rank[rootY]:
                self.root[rootY] = rootX
            elif self.rank[rootX] < self.rank[rootY]:
                self.root[rootX] = rootY
            else:
                self.root[rootY] = rootX
                self.rank[rootX] += 1
            self.count -= 1

    def getCount(self):
        return self.count

class Solution:
    def validTree(self, n: int, edges: List[List[int]]) -> bool:
        uf = UnionFind(n)
        for x in edges:
            uf.union(x[0], x[1])
            
        province = 0
        for i in range(n):
            if uf.root[i]==i:
                province+=1
    
        return (len(edges)==n-1) & (province==1)
    

In [None]:
# Number of Connected Components in an Undirected Graph

![image.png](attachment:image.png)

![image.png](attachment:image.png)

In [None]:
# Accepted V1
# UnionFind class
class UnionFind:
    def __init__(self, size):
        self.root = [i for i in range(size)]
        # Use a rank array to record the height of each vertex, i.e., the "rank" of each vertex.
        # The initial "rank" of each vertex is 1, because each of them is
        # a standalone vertex with no connection to other vertices.
        self.rank = [1] * size
        self.count = size

    # The find function here is the same as that in the disjoint set with path compression.
    def find(self, x):
        if x == self.root[x]:
            return x
        self.root[x] = self.find(self.root[x])
        return self.root[x]

    # The union function with union by rank
    def union(self, x, y):
        rootX = self.find(x)
        rootY = self.find(y)
        if rootX != rootY:
            if self.rank[rootX] > self.rank[rootY]:
                self.root[rootY] = rootX
            elif self.rank[rootX] < self.rank[rootY]:
                self.root[rootX] = rootY
            else:
                self.root[rootY] = rootX
                self.rank[rootX] += 1
            self.count -= 1

    def getCount(self):
        return self.count

class Solution:
    def countComponents(self, n: int, edges: List[List[int]]) -> int:
        uf = UnionFind(n)
        for x in edges:
            uf.union(x[0], x[1])
            
        province = 0
        for i in range(n):
            if uf.root[i]==i:
                province+=1
    
        return province
    



In [None]:
# The Earliest Moment When Everyone Become Friends

![image.png](attachment:image.png)

![image.png](attachment:image.png)

In [None]:
# Accepted V1
# UnionFind class
class UnionFind:
    def __init__(self, size):
        self.root = [i for i in range(size)]
        # Use a rank array to record the height of each vertex, i.e., the "rank" of each vertex.
        # The initial "rank" of each vertex is 1, because each of them is
        # a standalone vertex with no connection to other vertices.
        self.rank = [1] * size
        self.count = size

    # The find function here is the same as that in the disjoint set with path compression.
    def find(self, x):
        if x == self.root[x]:
            return x
        self.root[x] = self.find(self.root[x])
        return self.root[x]

    # The union function with union by rank
    def union(self, x, y):
        rootX = self.find(x)
        rootY = self.find(y)
        if rootX != rootY:
            if self.rank[rootX] > self.rank[rootY]:
                self.root[rootY] = rootX
            elif self.rank[rootX] < self.rank[rootY]:
                self.root[rootX] = rootY
            else:
                self.root[rootY] = rootX
                self.rank[rootX] += 1
            self.count -= 1

    # connected function of the “disjoint set”
    def connected(self, x, y):
        return self.find(x) == self.find(y)

            
    def getCount(self):
        return self.count

class Solution:
    def earliestAcq(self, logs: List[List[int]], n: int) -> int:
        logs = sorted(logs, key=lambda x:x[0])
        
        uf = UnionFind(n)
        count=n
        for x in logs:
            if not uf.connected(x[1],x[2]):
                uf.union(x[1],x[2])
                count-=1
            if count==1:
                return x[0]
            
        return -1
        

![image.png](attachment:image.png)

![image.png](attachment:image.png)

In [None]:
# Accepted V1
# UnionFind class

from collections import defaultdict

class UnionFind:
    def __init__(self, size):
        self.root = [i for i in range(size)]
        # Use a rank array to record the height of each vertex, i.e., the "rank" of each vertex.
        # The initial "rank" of each vertex is 1, because each of them is
        # a standalone vertex with no connection to other vertices.
        self.rank = [1] * size
        self.count = size

    # The find function here is the same as that in the disjoint set with path compression.
    def find(self, x):
        if x == self.root[x]:
            return x
        self.root[x] = self.find(self.root[x])
        return self.root[x]

    # The union function with union by rank
    def union(self, x, y):
        rootX = self.find(x)
        rootY = self.find(y)
        if rootX != rootY:
            if self.rank[rootX] > self.rank[rootY]:
                self.root[rootY] = rootX
            elif self.rank[rootX] < self.rank[rootY]:
                self.root[rootX] = rootY
            else:
                self.root[rootY] = rootX
                self.rank[rootX] += 1
            self.count -= 1

    # connected function of the “disjoint set”
    def connected(self, x, y):
        return self.find(x) == self.find(y)

            
    def getCount(self):
        return self.count


class Solution:
    def smallestStringWithSwaps(self, s: str, pairs: List[List[int]]) -> str:
        uf = UnionFind(len(s))
        for x in pairs:            
            uf.union(x[0],x[1])

        group_dict = defaultdict(list)
        for x in range(len(uf.root)):
            rootX = uf.find(x) 
            group_dict[rootX].append(x)
            
        result = ['']*len(s)
        for k,x_list in group_dict.items():
            x_list = sorted(x_list)
            value_list = [s[x] for x in x_list]
            value_list = sorted(value_list)
            for i,x in enumerate(x_list):
                result[x] = value_list[i]
            
        return ''.join(result)
        


In [None]:
# Evaluate Division

![image.png](attachment:image.png)

![image.png](attachment:image.png)

In [None]:
# Failed v1

from collections import defaultdict

class UnionFind:
    def __init__(self, size):
        self.root = [i for i in range(size)]
        self.rank = [1] * size
        self.count = size

    # The find function here is the same as that in the disjoint set with path compression.
    def find(self, x):
        history = []
        while x != self.root[x]:
            x = self.root[x]
        return x

    # The union function with union by rank
    def union(self, x, y):
        rootX = self.find(x)
        rootY = self.find(y)
        if rootX != rootY:
            if self.rank[rootX] > self.rank[rootY]:
                self.root[rootY] = rootX
            elif self.rank[rootX] < self.rank[rootY]:
                self.root[rootX] = rootY
            else:
                self.root[rootY] = rootX
                self.rank[rootX] += 1
            self.count -= 1

    # connected function of the “disjoint set”
    def connected(self, x, y):
        return self.find(x) == self.find(y)

    def getCount(self):
        return self.count


class Solution:
    def calcEquation(self, equations: List[List[str]], values: List[float], queries: List[List[str]]) -> List[float]:
        
        unit_set = set()    
        for item in equations:
            unit_set.add(item[0])
            unit_set.add(item[1])
        nodes = sorted(list(unit_set))
        edges = [[nodes.index(x[0]), nodes.index(x[1])] for x in equations]
        uf = UnionFind(len(nodes))
        for x in edges:            
            uf.union(x[0],x[1])
            
        print('Nodes:', nodes)
        print("Edges:", edges)
        print('Root:', uf.root)
        for query in queries:
            query_t = [nodes.index(query[0]), nodes.index(query[1])]
            print(query)
            print(query_t)
            break

        return 
        


### Approach 1: Path Search in Graph

In [None]:
class Solution:
    def calcEquation(self, equations: List[List[str]], values: List[float], queries: List[List[str]]) -> List[float]:

        graph = defaultdict(defaultdict)

        def backtrack_evaluate(curr_node, target_node, acc_product, visited):
            visited.add(curr_node)
            ret = -1.0
            neighbors = graph[curr_node]
            if target_node in neighbors:
                ret = acc_product * neighbors[target_node]
            else:
                for neighbor, value in neighbors.items():
                    if neighbor in visited:
                        continue
                    ret = backtrack_evaluate(
                        neighbor, target_node, acc_product * value, visited)
                    if ret != -1.0:
                        break
            visited.remove(curr_node)
            return ret

        # Step 1). build the graph from the equations
        for (dividend, divisor), value in zip(equations, values):
            # add nodes and two edges into the graph
            graph[dividend][divisor] = value
            graph[divisor][dividend] = 1 / value

        # Step 2). Evaluate each query via backtracking (DFS)
        #  by verifying if there exists a path from dividend to divisor
        results = []
        for dividend, divisor in queries:
            if dividend not in graph or divisor not in graph:
                # case 1): either node does not exist
                ret = -1.0
            elif dividend == divisor:
                # case 2): origin and destination are the same node
                ret = 1.0
            else:
                visited = set()
                ret = backtrack_evaluate(dividend, divisor, 1, visited)
            results.append(ret)

        return results

### Approach 2: Union-Find with Weights

In [None]:
class Solution:
    def calcEquation(self, equations: List[List[str]], values: List[float], queries: List[List[str]]) -> List[float]:

        gid_weight = {}

        def find(node_id):
            if node_id not in gid_weight:
                gid_weight[node_id] = (node_id, 1)
            group_id, node_weight = gid_weight[node_id]
            # The above statements are equivalent to the following one
            #group_id, node_weight = gid_weight.setdefault(node_id, (node_id, 1))

            if group_id != node_id:
                # found inconsistency, trigger chain update
                new_group_id, group_weight = find(group_id)
                gid_weight[node_id] = \
                    (new_group_id, node_weight * group_weight)
            return gid_weight[node_id]

        def union(dividend, divisor, value):
            dividend_gid, dividend_weight = find(dividend)
            divisor_gid, divisor_weight = find(divisor)
            if dividend_gid != divisor_gid:
                # merge the two groups together,
                # by attaching the dividend group to the one of divisor
                gid_weight[dividend_gid] = \
                    (divisor_gid, divisor_weight * value / dividend_weight)

        # Step 1). build the union groups
        for (dividend, divisor), value in zip(equations, values):
            union(dividend, divisor, value)

        results = []
        # Step 2). run the evaluation, with "lazy" updates in find() function
        for (dividend, divisor) in queries:
            if dividend not in gid_weight or divisor not in gid_weight:
                # case 1). at least one variable did not appear before
                results.append(-1.0)
            else:
                dividend_gid, dividend_weight = find(dividend)
                divisor_gid, divisor_weight = find(divisor)
                if dividend_gid != divisor_gid:
                    # case 2). the variables do not belong to the same chain/group
                    results.append(-1.0)
                else:
                    # case 3). there is a chain/path between the variables
                    results.append(dividend_weight / divisor_weight)
        return results

### Checker Later
https://leetcode.com/explore/learn/card/graph/618/disjoint-set/3914/

In [9]:
# Optimize Water Distribution in a Village

![image.png](attachment:image.png)

![image.png](attachment:image.png)

### Check Later
https://leetcode.com/explore/learn/card/graph/618/disjoint-set/3916/

In [None]:
# Draft v1
from collections import deque
class Node:
    def __init__(self, val):
        self.val = val
        self.well_cost = None
        self.pipes = dict()
        
def packupNode(n, well_cost, pipes):
    # Initial Node
    node_dict = {i:Node(i) for i in range(1,n+1)}
    
    for i in range(len(well_cost)):
        node_dict[i+1].well_cost = well_cost[i]
        
    # Add Edges  
    for edge in pipes:
        node_a = node_dict[edge[0]]
        node_b = node_dict[edge[1]]
        cost = edge[2]

        node_b.pipes[node_a.val] = cost
        node_a.pipes[node_b.val] = cost 

    return node_dict



def checkAround(node, node_dict):
    stack = deque()
    group_set = set()
    stack.append(node)
    
    while(stack):
        cur = stack.pop()
        group_set.add(cur)
        for pipe, pipe_cost in cur.pipes.items():
            adj_node = node_dict[pipe]
            if not adj_node in group_set:
                if pipe_cost <= adj_node.well_cost:
                    stack.append(adj_node)    
    return group_set


def getGroupCost(group_set):
    # pipe cost
    
    
    
    
    # node cost
    
    pass

class Solution:
    def minCostToSupplyWater(self, n: int, wells: List[int], pipes: List[List[int]]) -> int:
        
        # Base
        checked_set = set()
        node_dict = packupNode(n, wells, pipes)
        node_list = list(node_dict.values())
        group_list = []
        
        #  
        while(node_list):
            cur = node_list.pop()
            print('cur:', cur.val)
            if cur in checked_set:
                continue
            
            group_set = checkAround(cur, node_dict)
            checked_set.update(group_set)
            group_list.append(group_set)
            
            # print("group_set:", [x.val for x in group_set])
            # break
            
        for group_set in group_list:
            print()
            print("group_set:", [x.val for x in group_set])
            
        # All Cost
        cost
            
        return
        
        
        
        
        

In [10]:
import random
random.choice([1,2,3])

1

In [11]:
group_set = set([1,2,3])
random.choice(group_set)

TypeError: 'set' object is not subscriptable

In [12]:
group_set = set([1,2,3])
next(iter(group_set))

1

In [None]:
# Failed v1

from collections import deque

class Node:
    def __init__(self, val):
        self.val = val
        self.well_cost = None
        self.pipes = dict()
        
def packupNode(n, well_cost, pipes):
    # Initial Node
    node_dict = {i:Node(i) for i in range(1,n+1)}
    
    for i in range(len(well_cost)):
        node_dict[i+1].well_cost = well_cost[i]
        
    # Add Edges  
    for edge in pipes:
        node_a = node_dict[edge[0]]
        node_b = node_dict[edge[1]]
        cost = edge[2]

        node_b.pipes[node_a.val] = cost
        node_a.pipes[node_b.val] = cost 

    return node_dict



def checkAround(node, node_dict):
    stack = deque()
    group_set = set()
    stack.append(node)
    
    while(stack):
        cur = stack.pop()
        group_set.add(cur)
        for pipe, pipe_cost in cur.pipes.items():
            adj_node = node_dict[pipe]
            if not adj_node in group_set:
                if pipe_cost <= adj_node.well_cost:
                    stack.append(adj_node)    
    return group_set


def getGroupCost(group_set, node_dict):
    
    # pipe cost
    pipe_cost = 0
    checked_set = set()
    stack = deque()
    stack.append(next(iter(group_set)))
    while(stack):
        cur = stack.pop()
        checked_set.add(cur)
        for neighbour,cost in cur.pipes.items():
            neighbour = node_dict[neighbour]
            if (not neighbour in checked_set) & (neighbour in group_set):
                stack.append(neighbour)
                pipe_cost += cost

    
    # node cost
    node_cost = min([x.well_cost for x in group_set])
    
    return pipe_cost+node_cost

class Solution:
    def minCostToSupplyWater(self, n: int, wells: List[int], pipes: List[List[int]]) -> int:
        
        # Base
        checked_set = set()
        node_dict = packupNode(n, wells, pipes)
        node_list = list(node_dict.values())
        group_list = []
        
        #  
        while(node_list):
            cur = node_list.pop()
            print('cur:', cur.val)
            if cur in checked_set:
                continue
            
            group_set = checkAround(cur, node_dict)
            checked_set.update(group_set)
            group_list.append(group_set)
            
            # print("group_set:", [x.val for x in group_set])
            # break
            
        for group_set in group_list:
            print()
            print("group_set:", [x.val for x in group_set])
            
        # All Cost
        all_cost = 0
        
        for group_set in group_list:
            group_cost = getGroupCost(group_set, node_dict)
            print()
            print('group_cost:',group_cost)
            
            all_cost += group_cost
        
            
            
            
        return all_cost
        







# Last Input
6
[4625,65696,86292,68291,37147,7880]
[[2,1,79394],[3,1,45649],[4,1,75810],[5,3,22340],[6,1,6222]]

In [None]:
"""
把你的图内容展示出来


"""

In [14]:
6222+4625+65696+68291+22340+37147

204321

In [None]:
# Failed v2  Node Groups   
from collections import deque

class Node:
    def __init__(self, val):
        self.val = val
        self.well_cost = None
        self.pipes = dict()
        
def packupNode(n, well_cost, pipes):
    # Initial Node
    node_dict = {i:Node(i) for i in range(1,n+1)}
    
    for i in range(len(well_cost)):
        node_dict[i+1].well_cost = well_cost[i]
        
    # Add Edges  
    for edge in pipes:
        node_a = node_dict[edge[0]]
        node_b = node_dict[edge[1]]
        cost = edge[2]

        node_b.pipes[node_a.val] = cost
        node_a.pipes[node_b.val] = cost 

    return node_dict


def checkAround(node, node_dict):
    stack = deque()
    group_set = set()
    stack.append(node)
    
    while(stack):
        cur = stack.pop()
        group_set.add(cur)
        for pipe, pipe_cost in cur.pipes.items():
            adj_node = node_dict[pipe]
            if not adj_node in group_set:
                if pipe_cost <= adj_node.well_cost:
                    stack.append(adj_node)    
    return group_set


# def getGroupCost(group_set, node_dict):
    
#     # pipe cost
#     pipe_cost = 0
#     checked_set = set()
#     stack = deque()
#     stack.append(next(iter(group_set)))
#     while(stack):
#         cur = stack.pop()
#         checked_set.add(cur)
#         for neighbour,cost in cur.pipes.items():
#             neighbour = node_dict[neighbour]
#             if (not neighbour in checked_set) & (neighbour in group_set):
#                 stack.append(neighbour)
#                 pipe_cost += cost

    
#     # node cost
#     node_cost = min([x.well_cost for x in group_set])
    
#     return pipe_cost, node_cost

def oneCheck(group):
    pass
    """
    Change on edge????
    
    """
    return 



def oneCheck(group):





def getGroupCost(group_set, node_dict):
    
    # pipe cost
    pipe_cost = 0
    checked_set = set()
    stack = deque()
    stack.append(next(iter(group_set)))
    while(stack):
        cur = stack.pop()
        checked_set.add(cur)
        for neighbour,cost in cur.pipes.items():
            neighbour = node_dict[neighbour]
            if (not neighbour in checked_set) & (neighbour in group_set):
                stack.append(neighbour)
                pipe_cost += cost

    # node cost
    node_cost = min([x.well_cost for x in group_set])
    
    return pipe_cost, node_cost


"""

检查部分的最优选择

"""

class Solution:
    def minCostToSupplyWater(self, n: int, wells: List[int], pipes: List[List[int]]) -> int:
        
        # Base
        checked_set = set()
        node_dict = packupNode(n, wells, pipes)
        node_list = list(node_dict.values())
        group_list = []
        
        #  
        while(node_list):
            cur = node_list.pop()
            print('cur:', cur.val)
            if cur in checked_set:
                continue
            
            group_set = checkAround(cur, node_dict)
            checked_set.update(group_set)
            group_list.append(group_set)
            
            # print("group_set:", [x.val for x in group_set])
            # break
            
        for group_set in group_list:
            print()
            print("group_set:", [x.val for x in group_set])
            
        # All Cost
        all_cost = 0
        
        for group_set in group_list:
            pipe_cost, node_cost = getGroupCost(group_set, node_dict)
            group_cost = pipe_cost + node_cost
            print()
            print('group_cost:', group_cost)
            
            all_cost += group_cost
              
            
        return all_cost
        
        
        
        
        

In [None]:
# Failed v3   Edge to Core    and   Group as Node
from collections import deque

class Node:
    def __init__(self, val):
        self.val = val
        self.well_cost = None
        self.pipes = dict()
        


class Solution:
    def minCostToSupplyWater(self, n: int, wells: List[int], pipes: List[List[int]]) -> int:
        
        # Base
        checked_set = set()
        node_dict = packupNode(n, wells, pipes)
        node_list = list(node_dict.values())

        
        # Get Edges
        """
            all edges
            min well cost 
            update edge->group dict
            checked_set

        """
        edge_list = [x for x in node_list if len(x.pipes)<=1]
        edge_group_dict = {x.val:{'min_well_cost':x.well_cost,
                                  'pipe_cost':0,
                                  'edge':x,
                                  'excluded_edge':set()
                                 } for x in edge_list}
        stack = deque(edge_list)
        while(stack):
            cur = stack.pop()
            checked_set.add(cur)
            edge_group = edge_group_dict[cur.val]
            
            for neighbour,pipe_cost in cur.pipes:
                neighbour = node_dict[neighbour]
                if pipe_cost>neighbour.well_cost and pipe_cost>cur.well_cost:
                    edge_group['excluded_edge'].add(neighbour)
                else:
                    edge_group['']
        
        
        
        """
        min_well_cost > edge + well_cost
        
        
        this edge is the min edge  and edge        
        
        
        
        """
        
        
        
        
        # Check and ge group
        
        
        
        
        
        return all_cost
        

In [72]:
# Draft v4
# Failed v2  Node Groups   
from collections import deque,Counter

class Node:
    def __init__(self, val):
        self.val = val
        self.well_cost = None
        self.pipes = dict()
        self.info = {
            'connected_min_well_node':{},
            'connected_min_well_cost':{},
            'connected_pipe_cost':{},
            'connected_pipe_set':{}
        }
        



"""

检查部分的最优选择

"""

class Solution:
    
    
    def minCostToSupplyWater(self, n, wells, pipes):
        

        def packupNode(n, well_cost, pipes):
            # Initial Node
            node_dict = {i:Node(i) for i in range(1,n+1)}

            for i in range(len(well_cost)):
                node_dict[i+1].well_cost = well_cost[i]

            # Add Edges  
            for edge in pipes:
                node_a = node_dict[edge[0]]
                node_b = node_dict[edge[1]]
                cost = edge[2]

                node_b.pipes[node_a.val] = cost
                node_a.pipes[node_b.val] = cost 

            return node_dict


        def checkAround(node, node_dict):
            stack = deque()
            group_set = set()
            stack.append(node)

            while(stack):
                cur = stack.pop()
                group_set.add(cur)
                for pipe, pipe_cost in cur.pipes.items():
                    adj_node = node_dict[pipe]
                    if not adj_node in group_set:
                        if pipe_cost <= cur.well_cost or pipe_cost <= adj_node.well_cost:
                            stack.append(adj_node)    
            return group_set


        def getGroupCost(group_set, node_dict):

            """
                Edge Stack
                Fork Stack
            """
            well_status = {}
            pipe_status = {}

            def get_fork(group_set):
                return [x for x in group_set if len(x.pipes)>2] 

            def get_edge(group_set):
                return [x for x in group_set if len(x.pipes)==1]

            t(1)
            fork_list = get_fork(group_set)
            edge_list = get_edge(group_set)

            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])



            # Round1 Edge to Fork (or edge)

            def update_pipe_status(node1,node2,status):
                key = tuple(sorted([node1.val, node2.val]))
                pipe_status[key] = 1



            def update_value(pre,cur,direction):
                if not pre:
                    cur.info['connected_min_well_cost'][direction] = cur.well_cost
                    cur.info['connected_pipe_cost'][direction] = 0
                    cur.info['connected_min_well_node'][direction] = cur.val
                    return

                pipe_cost = cur.pipes[pre.val]
                well_cost = cur.well_cost
                pre_connected_min_well_cost = pre.info['connected_min_well_cost'][direction]
                pre_connected_pipe_cost = pre.info['connected_pipe_cost'][direction]
                pre_connected_min_well_node = pre.info['connected_min_well_node'][direction]
        #         pre_connected_pipe_set = pre.info['connected_pipe_set'][direction]
                cur_connected_min_well_node = cur.info['connected_min_well_node'].get(direction, -1)


                if well_cost<(pre_connected_min_well_cost + pre_connected_pipe_cost):
                    cur_connected_min_well_cost = well_cost
                    cur_connected_pipe_cost = 0
        #             cur_connected_pipe_cost = set()
                    # Disconnect Pipe
                    update_pipe_status(pre, cur, 0)
                    cur_connected_min_well_node = cur.val
                    well_status[cur_connected_min_well_node] = 1

                else:
                    cur_connected_min_well_cost = pre_connected_min_well_cost
                    cur_connected_pipe_cost = pre_connected_pipe_cost + pipe_cost
                    # Connect pipe
                    update_pipe_status(pre, cur, 1)
                    if cur_connected_min_well_node==cur.val:
                        well_status[cur_connected_min_well_node] = 0
                    cur_connected_min_well_node = pre_connected_min_well_node

                cur.info['connected_min_well_cost'][direction] = cur_connected_min_well_cost
                cur.info['connected_pipe_cost'][direction] = cur_connected_pipe_cost
                cur.info['connected_min_well_node'][direction] = cur_connected_min_well_node

            def extend_edge(node, pre=None):
                # Source node as direction
                if pre:
                    direction = pre.val
                else:
                    direction = node.val

                cur = node
                while cur and (not cur in fork_list):
                    print('extend_edge cur:',cur.val if cur else None)
                    print('pre:',pre.val if pre else None)
                    # Update middle node
                    update_value(pre,cur,direction)
                    tmp = cur
                    cur = [node_dict[x] for x in cur.pipes if (not pre) or (x!=pre.val)]
                    cur = cur[0] if cur else None
                    pre = tmp


                # Update fork
                if cur:
        #             pass
                    update_value(pre,cur,direction)
                #cur['in_path_node_status'][cur.val] = 1 

            t(2)
            for edge in edge_list:
                extend_edge(edge)


            # Round2 Fork to Edge (or fork)
            def extend_fork(node):
                base = node
                fork_direction = base.val

                connected_min_well_cost = base.well_cost
                connected_pipe_cost = 0
                for edge_direction in base.pipes:
                    cur_connected_min_well_cost = base.info['connected_min_well_cost'].get(
                                                edge_direction,float('inf'))
                    cur_connected_pipe_cost = base.info['connected_pipe_cost'].get(
                                                edge_direction,float('inf'))

                    if (cur_connected_min_well_cost+cur_connected_pipe_cost) <(
                        connected_min_well_cost+connected_pipe_cost
                    ):
                        connected_min_well_cost = cur_connected_min_well_cost
                        connected_pipe_cost = cur_connected_pipe_cost

                for edge_direction in base.pipes:
                    cur_connected_min_well_cost = base.info['connected_min_well_cost'].get(
                                                edge_direction,float('inf'))
                    cur_connected_pipe_cost = base.info['connected_pipe_cost'].get(
                                                edge_direction,float('inf'))

                    if (cur_connected_min_well_cost+cur_connected_pipe_cost) >(
                        connected_min_well_cost+connected_pipe_cost
                    ):
                        cur = node_dict[edge_direction]
                        extend_edge(cur, pre=base)

            t(3)
            for fork in fork_list:
                extend_fork(fork)
            # Recording Updated edge or fork


            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])


            def clean_off():
                well_edges = [x for x in pipe_status if (x[0] in well_status) or (x[1] in well_status)]
                print("well_edges:",well_edges)
                tmp = []
                for x in well_edges:
                    tmp+=x
                well_edges = tmp
                well_edges = Counter(well_edges)
                print("well_edges:",well_edges)
                well_edges = {k for k,v in well_edges.items() if (not k in well_status) and (v>1)}
                print("well_edges:", well_edges)
                
                for node in well_edges:
                    node = node_dict[node]
                    min_pipe,min_cost = sorted([(pipe,cost) for pipe,cost in node.pipes.items()],key=lambda x:x[1])[0]
                    for pipe,cost in node.pipes.items():
                        if (cost>=min_cost) and (pipe!=min_pipe):
                            key = tuple(sorted([pipe,node.val]))
                            pipe_status[key]=0
                        
                
            
            clean_off()
            
            print()
            for x in node_dict.values():
                print('node:',x.val)
                print('info:',x.info)
                print()


            print("pipe_status:",pipe_status)
            print("well_status:",well_status)
            
            
            
            pipe_cost = sum([pipe_cost_dict[k] for k,v in pipe_status.items() if v])
            node_cost = sum([well_cost_dict[k] for k,v in well_status.items() if v])

            return pipe_cost, node_cost


        
        
        pipe_cost_dict = {tuple(sorted((x[0],x[1]))):x[2] for x in pipes }
        well_cost_dict = {(i+1):x for i,x in enumerate(wells)}
        
        # Base
        checked_set = set()
        node_dict = packupNode(n, wells, pipes)
        node_list = list(node_dict.values())
        group_list = []
        
        #  
        while(node_list):
            cur = node_list.pop()
            print('cur:', cur.val)
            if cur in checked_set:
                continue
            
            group_set = checkAround(cur, node_dict)
            checked_set.update(group_set)
            group_list.append(group_set)
            
            # print("group_set:", [x.val for x in group_set])
            # break
            
        for group_set in group_list:
            print()
            print("group_set:", [x.val for x in group_set])
            
        # All Cost
        all_cost = 0
        
        for group_set in group_list:
            print("group_set:",[x.val for x in group_set])
            pipe_cost, node_cost = getGroupCost(group_set, node_dict)
            group_cost = pipe_cost + node_cost
            print()
            print('group_cost:', group_cost)
            
            all_cost += group_cost
              
            
        return all_cost
        
t = lambda x: print(x)
# n = 3
# wells = [1,2,2]
# pipes = [[1,2,1],[2,3,1]]
        
n = 3
wells = [4600, 86000, 37000]
pipes = [[1,2,45000],[2,3,22000]]
    
    
tmp = Solution()
tmp.minCostToSupplyWater(n, wells, pipes)

cur: 3
cur: 2
cur: 1

group_set: [3, 1, 2]
group_set: [3, 1, 2]
1
Fork_List: []
Edge_List: [3, 1]
2
extend_edge cur: 3
pre: None
extend_edge cur: 2
pre: 3
extend_edge cur: 1
pre: 2
extend_edge cur: 1
pre: None
extend_edge cur: 2
pre: 1
extend_edge cur: 3
pre: 2
3
Fork_List: []
Edge_List: [3, 1]
well_edges: [(2, 3), (1, 2)]
well_edges: Counter({2: 2, 3: 1, 1: 1})
well_edges: {2}

node: 1
info: {'connected_min_well_node': {3: 1, 1: 1}, 'connected_min_well_cost': {3: 4600, 1: 4600}, 'connected_pipe_cost': {3: 0, 1: 0}, 'connected_pipe_set': {}}

node: 2
info: {'connected_min_well_node': {3: 3, 1: 1}, 'connected_min_well_cost': {3: 37000, 1: 4600}, 'connected_pipe_cost': {3: 22000, 1: 45000}, 'connected_pipe_set': {}}

node: 3
info: {'connected_min_well_node': {3: 3, 1: 3}, 'connected_min_well_cost': {3: 37000, 1: 37000}, 'connected_pipe_cost': {3: 0, 1: 0}, 'connected_pipe_set': {}}

pipe_status: {(2, 3): 1, (1, 2): 0}
well_status: {1: 1, 3: 1}

group_cost: 63600


63600

In [None]:
# Failed v5
# Draft v4
# Failed v2  Node Groups   
from collections import deque,Counter

class Node:
    def __init__(self, val):
        self.val = val
        self.well_cost = None
        self.pipes = dict()
        self.info = {
            'connected_min_well_node':{},
            'connected_min_well_cost':{},
            'connected_pipe_cost':{},
            'connected_pipe_set':{}
        }
        



"""

检查部分的最优选择

"""

class Solution:
    
    
    def minCostToSupplyWater(self, n, wells, pipes):
        

        def packupNode(n, well_cost, pipes):
            # Initial Node
            node_dict = {i:Node(i) for i in range(1,n+1)}

            for i in range(len(well_cost)):
                node_dict[i+1].well_cost = well_cost[i]

            # Add Edges  
            for edge in pipes:
                node_a = node_dict[edge[0]]
                node_b = node_dict[edge[1]]
                cost = edge[2]

                node_b.pipes[node_a.val] = cost
                node_a.pipes[node_b.val] = cost 

            return node_dict


        def checkAround(node, node_dict):
            stack = deque()
            group_set = set()
            stack.append(node)

            while(stack):
                cur = stack.pop()
                group_set.add(cur)
                for pipe, pipe_cost in cur.pipes.items():
                    adj_node = node_dict[pipe]
                    if not adj_node in group_set:
                        if pipe_cost <= cur.well_cost or pipe_cost <= adj_node.well_cost:
                            stack.append(adj_node)    
            return group_set


        def getGroupCost(group_set, node_dict):

            """
                Edge Stack
                Fork Stack
            """
            well_status = {}
            pipe_status = {}

            def get_fork(group_set):
                return [x for x in group_set if len(x.pipes)>2] 

            def get_edge(group_set):
                return [x for x in group_set if len(x.pipes)==1]

            fork_list = get_fork(group_set)
            edge_list = get_edge(group_set)

            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])



            # Round1 Edge to Fork (or edge)

            def update_pipe_status(node1,node2,status):
                key = tuple(sorted([node1.val, node2.val]))
                pipe_status[key] = 1



            def update_value(pre,cur,direction):
                if not pre:
                    cur.info['connected_min_well_cost'][direction] = cur.well_cost
                    cur.info['connected_pipe_cost'][direction] = 0
                    cur.info['connected_min_well_node'][direction] = cur.val
                    well_status[cur.val] = 1
                    return

                pipe_cost = cur.pipes[pre.val]
                well_cost = cur.well_cost
                pre_connected_min_well_cost = pre.info['connected_min_well_cost'][direction]
                pre_connected_pipe_cost = pre.info['connected_pipe_cost'][direction]
                pre_connected_min_well_node = pre.info['connected_min_well_node'][direction]
        #         pre_connected_pipe_set = pre.info['connected_pipe_set'][direction]
                cur_connected_min_well_node = cur.info['connected_min_well_node'].get(direction, -1)


                if well_cost<(pre_connected_min_well_cost + pre_connected_pipe_cost):
                    cur_connected_min_well_cost = well_cost
                    cur_connected_pipe_cost = 0
        #             cur_connected_pipe_cost = set()
                    # Disconnect Pipe
                    update_pipe_status(pre, cur, 0)
                    cur_connected_min_well_node = cur.val
                    well_status[cur_connected_min_well_node] = 1
                    print('tata:', well_status)

                else:
                    cur_connected_min_well_cost = pre_connected_min_well_cost
                    cur_connected_pipe_cost = pre_connected_pipe_cost + pipe_cost
                    # Connect pipe
                    update_pipe_status(pre, cur, 1)
                    if cur_connected_min_well_node==cur.val:
                        well_status[cur_connected_min_well_node] = 0
                    cur_connected_min_well_node = pre_connected_min_well_node
                    print('tata2:',cur_connected_min_well_node)
                    
                cur.info['connected_min_well_cost'][direction] = cur_connected_min_well_cost
                cur.info['connected_pipe_cost'][direction] = cur_connected_pipe_cost
                cur.info['connected_min_well_node'][direction] = cur_connected_min_well_node

            def extend_edge(node, pre=None):
                # Source node as direction
                if pre:
                    direction = pre.val
                else:
                    direction = node.val

                cur = node
                while cur and (not cur in fork_list):
                    print('extend_edge cur:',cur.val if cur else None)
                    print('pre:',pre.val if pre else None)
                    # Update middle node
                    update_value(pre,cur,direction)
                    tmp = cur
                    cur = [node_dict[x] for x in cur.pipes if (not pre) or (x!=pre.val)]
                    cur = cur[0] if cur else None
                    pre = tmp


                # Update fork
                if cur:
        #             pass
                    update_value(pre,cur,direction)
                #cur['in_path_node_status'][cur.val] = 1 

            for edge in edge_list:
                extend_edge(edge)


            # Round2 Fork to Edge (or fork)
            def extend_fork(node):
                base = node
                fork_direction = base.val

                connected_min_well_cost = base.well_cost
                connected_pipe_cost = 0
                for edge_direction in base.pipes:
                    cur_connected_min_well_cost = base.info['connected_min_well_cost'].get(
                                                edge_direction,float('inf'))
                    cur_connected_pipe_cost = base.info['connected_pipe_cost'].get(
                                                edge_direction,float('inf'))

                    if (cur_connected_min_well_cost+cur_connected_pipe_cost) <(
                        connected_min_well_cost+connected_pipe_cost
                    ):
                        connected_min_well_cost = cur_connected_min_well_cost
                        connected_pipe_cost = cur_connected_pipe_cost

                for edge_direction in base.pipes:
                    cur_connected_min_well_cost = base.info['connected_min_well_cost'].get(
                                                edge_direction,float('inf'))
                    cur_connected_pipe_cost = base.info['connected_pipe_cost'].get(
                                                edge_direction,float('inf'))

                    if (cur_connected_min_well_cost+cur_connected_pipe_cost) >(
                        connected_min_well_cost+connected_pipe_cost
                    ):
                        cur = node_dict[edge_direction]
                        extend_edge(cur, pre=base)


            for fork in fork_list:
                extend_fork(fork)
            # Recording Updated edge or fork


            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])


            def clean_off():
                well_edges = [x for x in pipe_status if (x[0] in well_status) or (x[1] in well_status)]
                print("well_edges:",well_edges)
                tmp = []
                for x in well_edges:
                    tmp+=x
                well_edges = tmp
                well_edges = Counter(well_edges)
                print("well_edges:",well_edges)
                well_edges = {k for k,v in well_edges.items() if (not k in well_status) and (v>1)}
                print("well_edges:", well_edges)
                
                for node in well_edges:
                    node = node_dict[node]
                    min_pipe,min_cost = sorted([(pipe,cost) for pipe,cost in node.pipes.items()],key=lambda x:x[1])[0]
                    for pipe,cost in node.pipes.items():
                        if (cost>=min_cost) and (pipe!=min_pipe):
                            key = tuple(sorted([pipe,node.val]))
                            pipe_status[key]=0
                        
                
            
            clean_off()
            
            
            pipe_status = {k:v for k,v in pipe_status.items() if v}
            well_status = {k:v for k,v in well_status.items() if v}
            
            
            if len(well_status)==n and len(pipe_status)>0:
                pipe_status = {}
            
            
            print()
            for x in node_dict.values():
                print('node:',x.val)
                print('info:',x.info)
                print()


            print("pipe_status:",pipe_status)
            print("well_status:",well_status)
            
            
            
            pipe_cost = sum([pipe_cost_dict[k] for k,v in pipe_status.items() if v])
            node_cost = sum([well_cost_dict[k] for k,v in well_status.items() if v])

            print('pipe_cost:',pipe_cost)
            print('node_cost:',node_cost)
            return pipe_cost, node_cost


        
        
        pipe_cost_dict = {tuple(sorted((x[0],x[1]))):x[2] for x in pipes }
        well_cost_dict = {(i+1):x for i,x in enumerate(wells)}
        
        # Base
        checked_set = set()
        node_dict = packupNode(n, wells, pipes)
        node_list = list(node_dict.values())

        pipe_cost, node_cost = getGroupCost(node_list, node_dict)
        group_cost = pipe_cost + node_cost
            
        return group_cost
        
# t = lambda x: print(x)
# # n = 3
# # wells = [1,2,2]
# # pipes = [[1,2,1],[2,3,1]]
        
# n = 3
# wells = [4600, 86000, 37000]
# pipes = [[1,2,45000],[2,3,22000]]
    
    
# tmp = Solution()
# tmp.minCostToSupplyWater(n, wells, pipes)

In [None]:
# Failed v6
# Draft v4
# Failed v2  Node Groups   
from collections import deque,Counter

class Node:
    def __init__(self, val):
        self.val = val
        self.well_cost = None
        self.pipes = dict()
        self.info = {
            'connected_min_well_node':{},
            'connected_min_well_cost':{},
            'connected_pipe_cost':{},
            'connected_pipe_set':{}
        }
        



"""

检查部分的最优选择

"""

class Solution:
    
    
    def minCostToSupplyWater(self, n, wells, pipes):
        

        def packupNode(n, well_cost, pipes):
            # Initial Node
            node_dict = {i:Node(i) for i in range(1,n+1)}

            for i in range(len(well_cost)):
                node_dict[i+1].well_cost = well_cost[i]

            # Add Edges  
            for edge in pipes:
                node_a = node_dict[edge[0]]
                node_b = node_dict[edge[1]]
                cost = edge[2]

                node_b.pipes[node_a.val] = cost
                node_a.pipes[node_b.val] = cost 

            return node_dict


        def checkAround(node, node_dict):
            stack = deque()
            group_set = set()
            stack.append(node)

            while(stack):
                cur = stack.pop()
                group_set.add(cur)
                for pipe, pipe_cost in cur.pipes.items():
                    adj_node = node_dict[pipe]
                    if not adj_node in group_set:
                        if pipe_cost <= cur.well_cost or pipe_cost <= adj_node.well_cost:
                            stack.append(adj_node)    
            return group_set


        def getGroupCost(group_set, node_dict):

            """
                Edge Stack
                Fork Stack
            """
            well_status = {}
            pipe_status = {}

            def get_fork(group_set):
                return [x for x in group_set if len(x.pipes)>2] 

            def get_edge(group_set):
                return [x for x in group_set if len(x.pipes)<=1]

            fork_list = get_fork(group_set)
            edge_list = get_edge(group_set)

            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])



            # Round1 Edge to Fork (or edge)

            def update_pipe_status(node1,node2,status):
                key = tuple(sorted([node1.val, node2.val]))
                pipe_status[key] = status


            def update_value(pre,cur,direction):           
                if not pre:
                    cur.info['connected_min_well_cost'][direction] = cur.well_cost
                    cur.info['connected_pipe_cost'][direction] = 0
                    cur.info['connected_min_well_node'][direction] = cur.val
                    well_status[cur.val] = 1
                    return

                pipe_cost = cur.pipes[pre.val]
                well_cost = cur.well_cost
                pre_connected_min_well_cost = pre.info['connected_min_well_cost'][direction]
                pre_connected_pipe_cost = pre.info['connected_pipe_cost'][direction]
                pre_connected_min_well_node = pre.info['connected_min_well_node'][direction]
        #         pre_connected_pipe_set = pre.info['connected_pipe_set'][direction]
                cur_connected_min_well_node = cur.info['connected_min_well_node'].get(direction, -1)


                if well_cost<(pre_connected_min_well_cost + pre_connected_pipe_cost):
                    cur_connected_min_well_cost = well_cost
                    cur_connected_pipe_cost = 0
        #             cur_connected_pipe_cost = set()
                    # Disconnect Pipe
                    update_pipe_status(pre, cur, 0)
                    cur_connected_min_well_node = cur.val
                    well_status[cur_connected_min_well_node] = 1
                    print('tata:', well_status)

                else:
                    cur_connected_min_well_cost = pre_connected_min_well_cost
                    cur_connected_pipe_cost = pre_connected_pipe_cost + pipe_cost
                    # Connect pipe
                    update_pipe_status(pre, cur, 1)
                    if cur_connected_min_well_node==cur.val:
                        well_status[cur_connected_min_well_node] = 0
                    cur_connected_min_well_node = pre_connected_min_well_node
                    print('tata2:',cur_connected_min_well_node)
                    
                cur.info['connected_min_well_cost'][direction] = cur_connected_min_well_cost
                cur.info['connected_pipe_cost'][direction] = cur_connected_pipe_cost
                cur.info['connected_min_well_node'][direction] = cur_connected_min_well_node
                
                
                counter_direction=None
                if len(cur.info['connected_min_well_cost'])==2:
                    counter_direction = [x for x in cur.info['connected_min_well_cost'] 
                                                    if x!=direction][0]
                    
                    
                if counter_direction:
                    if ((cur.info['connected_min_well_cost'][direction] 
                        + cur.info['connected_pipe_cost'][direction]
                        + cur.info['connected_pipe_cost'][counter_direction])>
                    (cur.info['connected_min_well_cost'][counter_direction] 
                        + cur.info['connected_pipe_cost'][counter_direction]
                        + cur.info['connected_pipe_cost'][direction])):
                        min_node = cur.info['connected_min_well_node'][direction]
                        well_status[min_node] = 0
                        if cur.info['connected_pipe_cost'][direction]:
                            update_pipe_status(pre, cur, 1)
                        print('cur.info:',cur.info)
                        print('direction:',direction)
                        print('lala:',cur.val, pre.val)
                        print('pipe_status:',pipe_status)
                        return True
                return False
                        

            def extend_edge(node, pre=None):
                # Source node as direction
                if pre:
                    direction = pre.val
                else:
                    direction = node.val

                cur = node
                while cur and (not cur in fork_list):
                    print('extend_edge cur:',cur.val if cur else None)
                    print('pre:',pre.val if pre else None)
                    # Update middle node
                    stop_ind = update_value(pre,cur,direction)
                    if stop_ind:
                        return
                    tmp = cur
                    cur = [node_dict[x] for x in cur.pipes if (not pre) or (x!=pre.val)]
                    cur = cur[0] if cur else None
                    pre = tmp


                # Update fork
                if cur:
        #             pass
                    update_value(pre,cur,direction)
                #cur['in_path_node_status'][cur.val] = 1 

            for edge in edge_list:
                extend_edge(edge)


            # Round2 Fork to Edge (or fork)
            def extend_fork(node):
                base = node
                fork_direction = base.val

                connected_min_well_cost = base.well_cost
                connected_pipe_cost = 0
                for edge_direction in base.pipes:
                    cur_connected_min_well_cost = base.info['connected_min_well_cost'].get(
                                                edge_direction,float('inf'))
                    cur_connected_pipe_cost = base.info['connected_pipe_cost'].get(
                                                edge_direction,float('inf'))

                    if (cur_connected_min_well_cost+cur_connected_pipe_cost) <(
                        connected_min_well_cost+connected_pipe_cost
                    ):
                        connected_min_well_cost = cur_connected_min_well_cost
                        connected_pipe_cost = cur_connected_pipe_cost

                for edge_direction in base.pipes:
                    cur_connected_min_well_cost = base.info['connected_min_well_cost'].get(
                                                edge_direction,float('inf'))
                    cur_connected_pipe_cost = base.info['connected_pipe_cost'].get(
                                                edge_direction,float('inf'))

                    if (cur_connected_min_well_cost+cur_connected_pipe_cost) >(
                        connected_min_well_cost+connected_pipe_cost
                    ):
                        cur = node_dict[edge_direction]
                        extend_edge(cur, pre=base)


            for fork in fork_list:
                extend_fork(fork)
            # Recording Updated edge or fork


            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])


            def clean_off():
                well_edges = [x for x in pipe_status if (x[0] in well_status) or (x[1] in well_status)]
                print("well_edges:",well_edges)
                tmp = []
                for x in well_edges:
                    tmp+=x
                well_edges = tmp
                well_edges = Counter(well_edges)
                print("well_edges:",well_edges)
                well_edges = {k for k,v in well_edges.items() if (not k in well_status) and (v>1)}
                print("well_edges:", well_edges)
                
                for node in well_edges:
                    node = node_dict[node]
                    min_pipe,min_cost = sorted([(pipe,cost) for pipe,cost 
                                                in node.pipes.items()],key=lambda x:x[1])[0]
                    for pipe,cost in node.pipes.items():
                        if (cost>=min_cost) and (pipe!=min_pipe):
                            key = tuple(sorted([pipe,node.val]))
                            pipe_status[key]=0
                        
                
            
            # clean_off()
            
            
            pipe_status = {k:v for k,v in pipe_status.items() if v}
            well_status = {k:v for k,v in well_status.items() if v}
            
            
            if len(well_status)==n and len(pipe_status)>0:
                pipe_status = {}
            
            
            print()
            for x in node_dict.values():
                print('node:',x.val)
                print('info:',x.info)
                print()


            print("pipe_status:",pipe_status)
            print("well_status:",well_status)
            
            
            
            pipe_cost = sum([pipe_cost_dict[k] for k,v in pipe_status.items() if v])
            node_cost = sum([well_cost_dict[k] for k,v in well_status.items() if v])

            print('pipe_cost:',pipe_cost)
            print('node_cost:',node_cost)
            return pipe_cost, node_cost


        
        
        pipe_cost_dict = {tuple(sorted((x[0],x[1]))):x[2] for x in pipes }
        well_cost_dict = {(i+1):x for i,x in enumerate(wells)}
        
        # Base
        checked_set = set()
        node_dict = packupNode(n, wells, pipes)
        node_list = list(node_dict.values())

        pipe_cost, node_cost = getGroupCost(node_list, node_dict)
        group_cost = pipe_cost + node_cost
            
        return group_cost
        
# t = lambda x: print(x)
# # n = 3
# # wells = [1,2,2]
# # pipes = [[1,2,1],[2,3,1]]
        
# n = 3
# wells = [4600, 86000, 37000]
# pipes = [[1,2,45000],[2,3,22000]]
    
    
# tmp = Solution()
# tmp.minCostToSupplyWater(n, wells, pipes)

In [None]:
# Failed v7   
from collections import deque,Counter

class Node:
    def __init__(self, val):
        self.val = val
        self.well_cost = None
        self.pipes = dict()
        self.info = {
            'connected_min_well_node':{},
            'connected_min_well_cost':{},
            'connected_pipe_cost':{},
            'connected_pipe_set':{}
        }
        



"""

检查部分的最优选择

"""

class Solution:
    
    
    def minCostToSupplyWater(self, n, wells, pipes):
        

        def packupNode(n, well_cost, pipes):
            # Initial Node
            node_dict = {i:Node(i) for i in range(1,n+1)}

            for i in range(len(well_cost)):
                node_dict[i+1].well_cost = well_cost[i]

            # Add Edges  
            for edge in pipes:
                node_a = node_dict[edge[0]]
                node_b = node_dict[edge[1]]
                cost = edge[2]

                node_b.pipes[node_a.val] = cost
                node_a.pipes[node_b.val] = cost 

            return node_dict


        def checkAround(node, node_dict):
            stack = deque()
            group_set = set()
            stack.append(node)

            while(stack):
                cur = stack.pop()
                group_set.add(cur)
                for pipe, pipe_cost in cur.pipes.items():
                    adj_node = node_dict[pipe]
                    if not adj_node in group_set:
                        if pipe_cost <= cur.well_cost or pipe_cost <= adj_node.well_cost:
                            stack.append(adj_node)    
            return group_set


        def getGroupCost(group_set, node_dict):

            """
                Edge Stack
                Fork Stack
            """
            well_status = {}
            pipe_status = {}

            def get_fork(group_set):
                return [x for x in group_set if len(x.pipes)>2] 

            def get_edge(group_set):
                return [x for x in group_set if len(x.pipes)<=1]

            fork_list = get_fork(group_set)
            edge_list = get_edge(group_set)

            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])



            # Round1 Edge to Fork (or edge)

            def update_pipe_status(node1,node2,status):
                key = tuple(sorted([node1.val, node2.val]))
                pipe_status[key] = status


            def update_value(pre,cur,direction):           
                if not pre:
                    cur.info['connected_min_well_cost'][direction] = cur.well_cost
                    cur.info['connected_pipe_cost'][direction] = 0
                    cur.info['connected_min_well_node'][direction] = cur.val
                    well_status[cur.val] = 1
                    return

                pipe_cost = cur.pipes[pre.val]
                well_cost = cur.well_cost
                
                print('tata3 cur:',cur.val)

                pre_connected_min_well_cost = pre.info['connected_min_well_cost'].get(direction, float('inf'))
                pre_connected_pipe_cost = pre.info['connected_pipe_cost'].get(direction, float('inf'))
                pre_connected_min_well_node = pre.info['connected_min_well_node'].get(direction, float('inf'))
        #         pre_connected_pipe_set = pre.info['connected_pipe_set'][direction]
                cur_connected_min_well_node = cur.info['connected_min_well_node'].get(direction, -1)


                if well_cost<(pre_connected_min_well_cost + pre_connected_pipe_cost):
                    cur_connected_min_well_cost = well_cost
                    cur_connected_pipe_cost = 0
        #             cur_connected_pipe_cost = set()
                    # Disconnect Pipe
                    update_pipe_status(pre, cur, 0)
                    cur_connected_min_well_node = cur.val
                    well_status[cur_connected_min_well_node] = 1
                    print('tata well_status:', well_status)
                    print('tata pipe_status:', pipe_status)

                else:
                    cur_connected_min_well_cost = pre_connected_min_well_cost
                    cur_connected_pipe_cost = pre_connected_pipe_cost + pipe_cost
                    # Connect pipe
                    update_pipe_status(pre, cur, 1)
                    if cur_connected_min_well_node==cur.val:
                        well_status[cur_connected_min_well_node] = 0
                    cur_connected_min_well_node = pre_connected_min_well_node
                    print('tata2:',cur_connected_min_well_node)
                    
                cur.info['connected_min_well_cost'][direction] = cur_connected_min_well_cost
                cur.info['connected_pipe_cost'][direction] = cur_connected_pipe_cost
                cur.info['connected_min_well_node'][direction] = cur_connected_min_well_node
                
                
                counter_direction=None
                if len(cur.info['connected_min_well_cost'])==2:
                    counter_direction = [x for x in cur.info['connected_min_well_cost'] 
                                                    if x!=direction][0]
                    
                    
                if counter_direction:
                    if ((cur.info['connected_min_well_cost'][direction] 
                        + cur.info['connected_pipe_cost'][direction]
                        + cur.info['connected_pipe_cost'][counter_direction])>
                    (cur.info['connected_min_well_cost'][counter_direction] 
                        + cur.info['connected_pipe_cost'][counter_direction]
                        + cur.info['connected_pipe_cost'][direction])):
                        min_node = cur.info['connected_min_well_node'][direction]
                        well_status[min_node] = 0
                        if cur.info['connected_pipe_cost'][direction]:
                            update_pipe_status(pre, cur, 1)
                        print('cur.info:',cur.info)
                        print('direction:',direction)
                        print('lala:',cur.val, pre.val)
                        print('pipe_status:',pipe_status)
                        return True
                return False
                        

            def extend_edge(node, pre=None):
                # Source node as direction
                if pre:
                    direction = pre.val
                else:
                    direction = node.val

                cur = node
                while cur and (not cur in fork_list):
                    print('extend_edge cur:',cur.val if cur else None)
                    print('pre:',pre.val if pre else None)
                    # Update middle node
                    stop_ind = update_value(pre,cur,direction)
                    if stop_ind:
                        return
                    tmp = cur
                    cur = [node_dict[x] for x in cur.pipes if (not pre) or (x!=pre.val)]
                    cur = cur[0] if cur else None
                    pre = tmp


                # Update fork
                if cur:
        #             pass
                    update_value(pre,cur,direction)
                #cur['in_path_node_status'][cur.val] = 1 

            for edge in edge_list:
                extend_edge(edge)


            # Round2 Fork to Edge (or fork)
            def extend_fork(node):
                base = node
                fork_direction = base.val

                connected_min_well_cost = base.well_cost
                connected_pipe_cost = 0
                for edge_direction in base.pipes:
                    cur_connected_min_well_cost = base.info['connected_min_well_cost'].get(
                                                edge_direction,float('inf'))
                    cur_connected_pipe_cost = base.info['connected_pipe_cost'].get(
                                                edge_direction,float('inf'))

                    if (cur_connected_min_well_cost+cur_connected_pipe_cost) <(
                        connected_min_well_cost+connected_pipe_cost
                    ):
                        connected_min_well_cost = cur_connected_min_well_cost
                        connected_pipe_cost = cur_connected_pipe_cost

                for edge_direction in base.pipes:
                    cur_connected_min_well_cost = base.info['connected_min_well_cost'].get(
                                                edge_direction,float('inf'))
                    cur_connected_pipe_cost = base.info['connected_pipe_cost'].get(
                                                edge_direction,float('inf'))

                    if (cur_connected_min_well_cost+cur_connected_pipe_cost) >(
                        connected_min_well_cost+connected_pipe_cost
                    ):
                        cur = node_dict[edge_direction]
                        extend_edge(cur, pre=base)


            for fork in fork_list:
                extend_fork(fork)
            # Recording Updated edge or fork


            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])


            def clean_off():
                well_edges = [x for x in pipe_status if (x[0] in well_status) or (x[1] in well_status)]
                print("well_edges:",well_edges)
                tmp = []
                for x in well_edges:
                    tmp+=x
                well_edges = tmp
                well_edges = Counter(well_edges)
                print("well_edges:",well_edges)
                well_edges = {k for k,v in well_edges.items() if (not k in well_status) and (v>1)}
                print("well_edges:", well_edges)
                
                for node in well_edges:
                    node = node_dict[node]
                    min_pipe,min_cost = sorted([(pipe,cost) for pipe,cost 
                                                in node.pipes.items()],key=lambda x:x[1])[0]
                    for pipe,cost in node.pipes.items():
                        if (cost>=min_cost) and (pipe!=min_pipe):
                            key = tuple(sorted([pipe,node.val]))
                            pipe_status[key]=0
                        
                
            
            # clean_off()
            
            
            pipe_status = {k:v for k,v in pipe_status.items() if v}
            well_status = {k:v for k,v in well_status.items() if v}
            
            
            if len(well_status)==n and len(pipe_status)>0:
                pipe_status = {}
            
            
            print()
            for x in node_dict.values():
                print('node:',x.val)
                print('info:',x.info)
                print()


            print("pipe_status:",pipe_status)
            print("well_status:",well_status)
            
            
            
            pipe_cost = sum([pipe_cost_dict[k] for k,v in pipe_status.items() if v])
            node_cost = sum([well_cost_dict[k] for k,v in well_status.items() if v])

            print('pipe_cost:',pipe_cost)
            print('node_cost:',node_cost)
            return pipe_cost, node_cost


        
        
        pipe_cost_dict = {tuple(sorted((x[0],x[1]))):x[2] for x in pipes }
        well_cost_dict = {(i+1):x for i,x in enumerate(wells)}
        
        # Base
        checked_set = set()
        node_dict = packupNode(n, wells, pipes)
        node_list = list(node_dict.values())

        pipe_cost, node_cost = getGroupCost(node_list, node_dict)
        group_cost = pipe_cost + node_cost
            
        return group_cost
        
# t = lambda x: print(x)
# # n = 3
# # wells = [1,2,2]
# # pipes = [[1,2,1],[2,3,1]]
        
# n = 3
# wells = [4600, 86000, 37000]
# pipes = [[1,2,45000],[2,3,22000]]
    
    
# tmp = Solution()
# tmp.minCostToSupplyWater(n, wells, pipes)

In [None]:
# Draft v8
from collections import deque,Counter

class Node:
    def __init__(self, val):
        self.val = val
        self.well_cost = None
        self.pipes = dict()
        self.info = {
            'connected_min_well_node':{},
            'connected_min_well_cost':{},
            'connected_pipe_cost':{},
            'connected_pipe_set':{}
        }
        
def packupNode(n, well_cost, pipes):
    # Initial Node
    node_dict = {i:Node(i) for i in range(1,n+1)}

    for i in range(len(well_cost)):
        node_dict[i+1].well_cost = well_cost[i]

    # Add Edges  
    for edge in pipes:
        node_a = node_dict[edge[0]]
        node_b = node_dict[edge[1]]
        cost = edge[2]

        node_b.pipes[node_a.val] = cost
        node_a.pipes[node_b.val] = cost 

    return node_dict
 
def get_fork(group_set):
    return [x for x in group_set if len(x.pipes)>2] 

def get_edge(group_set):
    return [x for x in group_set if len(x.pipes)<=1]



def update_pipe_status(node1,node2,status):
    key = tuple(sorted([node1.val, node2.val]))
    pipe_status[key] = status

    
def update_value(pre,cur,direction):           


def extend_edge(node, pre=None):
    # Source node as direction
    

def extend_fork(node):
    
    
def getGroupCost(group_set, node_dict):
    well_status = {}
    pipe_status = {}
    
    
    pipe_cost = sum([pipe_cost_dict[k] for k,v in pipe_status.items() if v])
    node_cost = sum([well_cost_dict[k] for k,v in well_status.items() if v])
    
    return pipe_cost, node_cost

class Solution:
    
    def minCostToSupplyWater(self, n, wells, pipes):
        
        
        pipe_cost_dict = {tuple(sorted((x[0],x[1]))):x[2] for x in pipes }
        well_cost_dict = {(i+1):x for i,x in enumerate(wells)}
        
        # Base
        node_dict = packupNode(n, wells, pipes)
        node_list = list(node_dict.values())
        

        
        
        
        
        
        
        
        
        
        
        
        
        pipe_cost, node_cost = getGroupCost(node_list, node_dict)
        group_cost = pipe_cost + node_cost
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        return group_cost


In [85]:
import inspect

def foo(*args):
    frame = inspect.currentframe()
    frame = inspect.getouterframes(frame)[1]
    string = inspect.getframeinfo(frame[0]).code_context[0].strip()
    args_base = string[string.find('(') + 1:-1].split(',')
    
    names = []
    for i in args_base:
        if i.find('=') != -1:
            names.append(i.split('=')[1].strip())
        
        else:
            names.append(i)
    
    print(names)
    return names, args

def main():
    e = 1
    c = 2
    foo(e, 1000, c)

main()

['e', ' 1000', ' c']


In [86]:
e = 1
c = 2
names, args = foo(e, 1000, c)

['e', ' 1000', ' c']


In [87]:
names

['e', ' 1000', ' c']

In [88]:
args

(1, 1000, 2)

In [90]:
import inspect

def get_varibale_pairs(*args):
    frame = inspect.currentframe()
    frame = inspect.getouterframes(frame)[1]
    string = inspect.getframeinfo(frame[0]).code_context[0].strip()
    args_base = string[string.find('(') + 1:-1].split(',')
    
    names = []
    for i in args_base:
        if i.find('=') != -1:
            names.append(i.split('=')[1].strip())
        
        else:
            names.append(i)
    
#     print(names)
#     return names, args
    varibale_pairs = {name:arg for name,arg in zip(names,args)}
    return varibale_pairs

e = 1
c = 2
varibale_pairs = get_varibale_pairs(e,1000,c)
varibale_pairs

{'e': 1, '1000': 1000, 'c': 2}

In [None]:
# Failed v8   

import inspect

def print_variables(*args):
    frame = inspect.currentframe()
    frame = inspect.getouterframes(frame)[1]
    string = inspect.getframeinfo(frame[0]).code_context[0].strip()
    args_base = string[string.find('(') + 1:-1].split(',')
   
    names = []
    for i in args_base:
        if i.find('=') != -1:
            names.append(i.split('=')[1].strip())
       
        else:
            names.append(i)

    variable_dict = {name:arg for name,arg in zip(names,args)}
    
    ind = names[0]
    print(f'\ntata{ind} start')
    for k,v in variable_dict.items():
        if k==ind:
            continue
        print(f'{k}: {v}')

    print(f'tata{ind} end\n')
    
pp = print_variables




from collections import deque,Counter

class Node:
    def __init__(self, val):
        self.val = val
        self.well_cost = None
        self.pipes = dict()
        self.info = {
            'connected_min_well_node':{},
            'connected_min_well_cost':{},
            'connected_pipe_cost':{},
            'connected_pipe_set':{}
        }
        

"""

检查部分的最优选择

"""

class Solution:
    
    
    def minCostToSupplyWater(self, n, wells, pipes):
        

        def packupNode(n, well_cost, pipes):
            # Initial Node
            node_dict = {i:Node(i) for i in range(1,n+1)}

            for i in range(len(well_cost)):
                node_dict[i+1].well_cost = well_cost[i]

            # Add Edges  
            for edge in pipes:
                node_a = node_dict[edge[0]]
                node_b = node_dict[edge[1]]
                cost = edge[2]

                node_b.pipes[node_a.val] = cost
                node_a.pipes[node_b.val] = cost 

            return node_dict


        def checkAround(node, node_dict):
            stack = deque()
            group_set = set()
            stack.append(node)

            while(stack):
                cur = stack.pop()
                group_set.add(cur)
                for pipe, pipe_cost in cur.pipes.items():
                    adj_node = node_dict[pipe]
                    if not adj_node in group_set:
                        if pipe_cost <= cur.well_cost or pipe_cost <= adj_node.well_cost:
                            stack.append(adj_node)    
            return group_set


        def getGroupCost(group_set, node_dict):

            """
                Edge Stack
                Fork Stack
            """
            well_status = {}
            pipe_status = {}

            def get_fork(group_set):
                return [x for x in group_set if len(x.pipes)>2] 

            def get_edge(group_set):
                return [x for x in group_set if len(x.pipes)<=1]

            fork_list = get_fork(group_set)
            edge_list = get_edge(group_set)

            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])



            # Round1 Edge to Fork (or edge)

            def update_pipe_status(node1,node2,status):
                key = tuple(sorted([node1.val, node2.val]))
                pipe_status[key] = status

                
            def update_value(pre,cur,direction):           
                if not pre:
                    cur.info['connected_min_well_cost'][direction] = cur.well_cost
                    cur.info['connected_pipe_cost'][direction] = 0
                    cur.info['connected_min_well_node'][direction] = cur.val
                    well_status[cur.val] = 1
                    return

                pipe_cost = cur.pipes[pre.val]
                well_cost = cur.well_cost
                
                print('tata3 cur:',cur.val)

                pre_connected_min_well_cost = pre.info['connected_min_well_cost'].get(direction, float('inf'))
                pre_connected_pipe_cost = pre.info['connected_pipe_cost'].get(direction, float('inf'))
                pre_connected_min_well_node = pre.info['connected_min_well_node'].get(direction, float('inf'))
        #         pre_connected_pipe_set = pre.info['connected_pipe_set'][direction]
                cur_connected_min_well_node = cur.info['connected_min_well_node'].get(direction, -1)


                if well_cost<(pre_connected_min_well_cost + pre_connected_pipe_cost):
                    cur_connected_min_well_cost = well_cost
                    cur_connected_pipe_cost = 0
        #             cur_connected_pipe_cost = set()
                    # Disconnect Pipe
                    update_pipe_status(pre, cur, 0)
                    cur_connected_min_well_node = cur.val
                    well_status[cur_connected_min_well_node] = 1
                    print('tata well_status:', well_status)
                    print('tata pipe_status:', pipe_status)

                else:
                    cur_connected_min_well_cost = pre_connected_min_well_cost
                    cur_connected_pipe_cost = pre_connected_pipe_cost + pipe_cost
                    # Connect pipe
                    update_pipe_status(pre, cur, 1)
                    if cur_connected_min_well_node==cur.val:
                        well_status[cur_connected_min_well_node] = 0
                    cur_connected_min_well_node = pre_connected_min_well_node
                    print('tata2:',cur_connected_min_well_node)
                    
                cur.info['connected_min_well_cost'][direction] = cur_connected_min_well_cost
                cur.info['connected_pipe_cost'][direction] = cur_connected_pipe_cost
                cur.info['connected_min_well_node'][direction] = cur_connected_min_well_node
                
                
                print('tata5 cur.info:',cur.info)
                
                
                counter_direction=None
                if len(cur.info['connected_min_well_cost'])==2:
                    counter_direction = [x for x in cur.info['connected_min_well_cost'] 
                                                    if x!=direction][0]
                    
                    
                if counter_direction:
                    if ((cur.info['connected_min_well_cost'][direction] 
                        + cur.info['connected_pipe_cost'][direction]
                        + cur.info['connected_pipe_cost'][counter_direction])>
                    (cur.info['connected_min_well_cost'][counter_direction] 
                        + cur.info['connected_pipe_cost'][counter_direction]
                        + cur.info['connected_pipe_cost'][direction])):
                        min_node = cur.info['connected_min_well_node'][direction]
                        well_status[min_node] = 0
                        if cur.info['connected_pipe_cost'][direction]:
                            update_pipe_status(pre, cur, 1)
                        print('cur.info:',cur.info)
                        print('direction:',direction)
                        print('lala:',cur.val, pre.val)
                        print('pipe_status:',pipe_status)
                        return True
                return False
                        

            def extend_edge(node, pre=None):
                # Source node as direction
                if pre:
                    direction = pre.val
                else:
                    direction = node.val

                cur = node
                while cur and (not cur in fork_list):
                    print('extend_edge cur:',cur.val if cur else None)
                    print('pre:',pre.val if pre else None)
                    # Update middle node
                    stop_ind = update_value(pre,cur,direction)
                    if stop_ind:
                        return
                    tmp = cur
                    cur = [node_dict[x] for x in cur.pipes if (not pre) or (x!=pre.val)]
                    cur = cur[0] if cur else None
                    pre = tmp


                # Update fork
                if cur:
        #             pass
                    print('tata4 cur', cur.val)
                    print('tata4 pre', pre.val)
                    update_value(pre,cur,direction)
                    #???????????
                #cur['in_path_node_status'][cur.val] = 1 

            for edge in edge_list:
                extend_edge(edge)


            # Round2 Fork to Edge (or fork)
            def extend_fork(node):
                base = node
                fork_direction = base.val

                connected_min_well_cost = base.well_cost
                connected_pipe_cost = 0
                # ???
                for edge_direction,edge_pipe_cost in base.pipes.items():
                    node = node_dict[edge_direction]
                    cur_connected_min_well_cost = node.info['connected_min_well_cost'].get(
                                                edge_direction,float('inf'))
                    cur_connected_pipe_cost = node.info['connected_pipe_cost'].get(
                                                edge_direction,float('inf')) + edge_pipe_cost

                    if (cur_connected_min_well_cost+cur_connected_pipe_cost) <(
                        connected_min_well_cost+connected_pipe_cost
                    ):
                        connected_min_well_cost = cur_connected_min_well_cost
                        connected_pipe_cost = cur_connected_pipe_cost

                        
                # Update Fork
                base.info['connected_min_well_cost'][fork_direction] = connected_min_well_cost
                base.info['connected_pipe_cost'][fork_direction] = connected_pipe_cost
                
                
                
                print('tata6 base:',base.val)
                print('connected_min_well_cost:', connected_min_well_cost)
                print('connected_pipe_cost:',connected_pipe_cost)
                print('tata6 end')
                
                pp(7,connected_min_well_cost,connected_pipe_cost)
                
                pp(8,base.val, connected_min_well_cost,connected_pipe_cost)
                
                
                for edge_direction, edge_pipe_cost in base.pipes.items():
                    print('edge_direction:', edge_direction)
                    node = node_dict[edge_direction]
                    cur_connected_min_well_cost = node.info['connected_min_well_cost'].get(
                                                edge_direction, float('inf'))
                    cur_connected_pipe_cost = node.info['connected_pipe_cost'].get(
                                                edge_direction, float('inf')) + edge_pipe_cost
                    
                    print('cur_connected_min_well_cost:',cur_connected_min_well_cost)
                    print('cur_connected_pipe_cost:',cur_connected_pipe_cost)

                    if (cur_connected_min_well_cost+cur_connected_pipe_cost) >(
                        connected_min_well_cost+connected_pipe_cost
                    ):
                        cur = node_dict[edge_direction]
                        extend_edge(cur, pre=base)

                        
            print('\nExtend Fork:\n')
                        

            for fork in fork_list:
                extend_fork(fork)
            # Recording Updated edge or fork
            
            
#             for edge in edge_list:
#                 extend_edge(edge)
            
#             for fork in fork_list:
#                 extend_fork(fork)


            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])


            def clean_off():
                well_edges = [x for x in pipe_status if (x[0] in well_status) or (x[1] in well_status)]
                print("well_edges:",well_edges)
                tmp = []
                for x in well_edges:
                    tmp+=x
                well_edges = tmp
                well_edges = Counter(well_edges)
                print("well_edges:",well_edges)
                well_edges = {k for k,v in well_edges.items() if (not k in well_status) and (v>1)}
                print("well_edges:", well_edges)
                
                for node in well_edges:
                    node = node_dict[node]
                    min_pipe,min_cost = sorted([(pipe,cost) for pipe,cost 
                                                in node.pipes.items()],key=lambda x:x[1])[0]
                    for pipe,cost in node.pipes.items():
                        if (cost>=min_cost) and (pipe!=min_pipe):
                            key = tuple(sorted([pipe,node.val]))
                            pipe_status[key]=0
                        
                
            
            # clean_off()
            
            
            pipe_status = {k:v for k,v in pipe_status.items() if v}
            well_status = {k:v for k,v in well_status.items() if v}
            
            
            if len(well_status)==n and len(pipe_status)>0:
                pipe_status = {}
            
            
            print()
            for x in node_dict.values():
                print('node:',x.val)
                print('info:',x.info)
                print()


            print("pipe_status:",pipe_status)
            print("well_status:",well_status)
            
            
            
            pipe_cost = sum([pipe_cost_dict[k] for k,v in pipe_status.items() if v])
            node_cost = sum([well_cost_dict[k] for k,v in well_status.items() if v])

            print('pipe_cost:',pipe_cost)
            print('node_cost:',node_cost)
            return pipe_cost, node_cost


        
        
        pipe_cost_dict = {tuple(sorted((x[0],x[1]))):x[2] for x in pipes }
        well_cost_dict = {(i+1):x for i,x in enumerate(wells)}
        
        # Base
        checked_set = set()
        node_dict = packupNode(n, wells, pipes)
        node_list = list(node_dict.values())

        pipe_cost, node_cost = getGroupCost(node_list, node_dict)
        group_cost = pipe_cost + node_cost
            
        return group_cost
        
# t = lambda x: print(x)
# # n = 3
# # wells = [1,2,2]
# # pipes = [[1,2,1],[2,3,1]]
        
# n = 3
# wells = [4600, 86000, 37000]
# pipes = [[1,2,45000],[2,3,22000]]
    
    
# tmp = Solution()
# tmp.minCostToSupplyWater(n, wells, pipes)

In [None]:
# Failed v9   


import inspect

def print_variables(*args):
    frame = inspect.currentframe()
    frame = inspect.getouterframes(frame)[1]
    string = inspect.getframeinfo(frame[0]).code_context[0].strip()
    args_base = string[string.find('(') + 1:-1].split(',')
   
    names = []
    for i in args_base:
        if i.find('=') != -1:
            names.append(i.split('=')[1].strip())
       
        else:
            names.append(i)

    variable_dict = {name:arg for name,arg in zip(names,args)}
    
    ind = names[0]
    print(f'\ntata{ind} start')
    for k,v in variable_dict.items():
        if k==ind:
            continue
        print(f'{k}: {v}')

    print(f'tata{ind} end\n')
    
pp = print_variables




from collections import deque,Counter

class Node:
    def __init__(self, val):
        self.val = val
        self.well_cost = None
        self.pipes = dict()
        self.info = {
            'connected_min_well_node':{},
            'connected_min_well_cost':{},
            'connected_pipe_cost':{},
            'connected_pipe_set':{}
        }
        

"""

检查部分的最优选择

"""

class Solution:
    
    
    def minCostToSupplyWater(self, n, wells, pipes):
        

        def packupNode(n, well_cost, pipes):
            # Initial Node
            node_dict = {i:Node(i) for i in range(1,n+1)}

            for i in range(len(well_cost)):
                node_dict[i+1].well_cost = well_cost[i]

            # Add Edges  
            for edge in pipes:
                node_a = node_dict[edge[0]]
                node_b = node_dict[edge[1]]
                cost = edge[2]

                node_b.pipes[node_a.val] = cost
                node_a.pipes[node_b.val] = cost 

            return node_dict


        def checkAround(node, node_dict):
            stack = deque()
            group_set = set()
            stack.append(node)

            while(stack):
                cur = stack.pop()
                group_set.add(cur)
                for pipe, pipe_cost in cur.pipes.items():
                    adj_node = node_dict[pipe]
                    if not adj_node in group_set:
                        if pipe_cost <= cur.well_cost or pipe_cost <= adj_node.well_cost:
                            stack.append(adj_node)    
            return group_set


        def getGroupCost(group_set, node_dict):

            """
                Edge Stack
                Fork Stack
            """
            well_status = {}
            pipe_status = {}

            def get_fork(group_set):
                return [x for x in group_set if len(x.pipes)>2] 

            def get_edge(group_set):
                return [x for x in group_set if len(x.pipes)<=1]

            fork_list = get_fork(group_set)
            edge_list = get_edge(group_set)

            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])



            # Round1 Edge to Fork (or edge)

            def update_pipe_status(node1,node2,status):
                key = tuple(sorted([node1.val, node2.val]))
                pipe_status[key] = status

                
            def update_value(pre,cur,direction):           
                if not pre:
                    cur.info['connected_min_well_cost'][direction] = cur.well_cost
                    cur.info['connected_pipe_cost'][direction] = 0
                    cur.info['connected_min_well_node'][direction] = cur.val
                    well_status[cur.val] = 1
                    return

                pipe_cost = cur.pipes[pre.val]
                well_cost = cur.well_cost
                
                print('tata3 cur:',cur.val)

                pre_connected_min_well_cost = pre.info['connected_min_well_cost'].get(direction, float('inf'))
                pre_connected_pipe_cost = pre.info['connected_pipe_cost'].get(direction, float('inf'))
                pre_connected_min_well_node = pre.info['connected_min_well_node'].get(direction, float('inf'))
        #         pre_connected_pipe_set = pre.info['connected_pipe_set'][direction]
                cur_connected_min_well_node = cur.info['connected_min_well_node'].get(direction, -1)


                if well_cost<(pre_connected_min_well_cost + pre_connected_pipe_cost):
                    cur_connected_min_well_cost = well_cost
                    cur_connected_pipe_cost = 0
        #             cur_connected_pipe_cost = set()
                    # Disconnect Pipe
                    update_pipe_status(pre, cur, 0)
                    cur_connected_min_well_node = cur.val
                    well_status[cur_connected_min_well_node] = 1
                    print('tata well_status:', well_status)
                    print('tata pipe_status:', pipe_status)

                else:
                    cur_connected_min_well_cost = pre_connected_min_well_cost
                    cur_connected_pipe_cost = pre_connected_pipe_cost + pipe_cost
                    # Connect pipe
                    update_pipe_status(pre, cur, 1)
                    if cur_connected_min_well_node==cur.val:
                        well_status[cur_connected_min_well_node] = 0
                    cur_connected_min_well_node = pre_connected_min_well_node
                    print('tata2:',cur_connected_min_well_node)
                    
                cur.info['connected_min_well_cost'][direction] = cur_connected_min_well_cost
                cur.info['connected_pipe_cost'][direction] = cur_connected_pipe_cost
                cur.info['connected_min_well_node'][direction] = cur_connected_min_well_node
                
                
                print('tata5 cur.info:',cur.info)
                
                
                counter_direction=None
                if len(cur.info['connected_min_well_cost'])==2:
                    counter_direction = [x for x in cur.info['connected_min_well_cost'] 
                                                    if x!=direction][0]
                    
                    
                if counter_direction:
                    
                    pp(9,direction,counter_direction,cur.info['connected_min_well_cost'][direction] ,cur.info['connected_min_well_cost'][counter_direction],cur.info['connected_pipe_cost'][direction],cur.info['connected_pipe_cost'][counter_direction])
                    
                    
                    
                    total_direction_cost = (cur.info['connected_min_well_cost'][direction] 
                        + cur.info['connected_pipe_cost'][direction]
                        + cur.info['connected_pipe_cost'][counter_direction])
                    total_counter_direction_cost = (cur.info['connected_min_well_cost'][counter_direction] 
                        + cur.info['connected_pipe_cost'][counter_direction]
                        #+ cur.info['connected_pipe_cost'][direction]
                        )
                    
                    
                    pp(9.1,total_direction_cost,total_counter_direction_cost)
                    
                    if total_direction_cost>total_counter_direction_cost:
                        min_node = cur.info['connected_min_well_node'][direction]
                        pp(10,min_node)
                        
                        well_status[min_node] = 0
                        if cur.info['connected_pipe_cost'][direction]:
                            update_pipe_status(pre, cur, 1)
                        print('cur.info:',cur.info)
                        print('direction:',direction)
                        print('lala:',cur.val, pre.val)
                        print('pipe_status:',pipe_status)
                        return True
                return False
                        

            def extend_edge(node, pre=None):
                # Source node as direction
                if pre:
                    direction = pre.val
                else:
                    direction = node.val

                cur = node
                while cur and (not cur in fork_list):
                    print('extend_edge cur:',cur.val if cur else None)
                    print('pre:',pre.val if pre else None)
                    # Update middle node
                    stop_ind = update_value(pre,cur,direction)
                    if stop_ind:
                        return
                    tmp = cur
                    cur = [node_dict[x] for x in cur.pipes if (not pre) or (x!=pre.val)]
                    cur = cur[0] if cur else None
                    pre = tmp


                # Update fork
                if cur:
        #             pass
                    print('tata4 cur', cur.val)
                    print('tata4 pre', pre.val)
                    update_value(pre,cur,direction)
                    #???????????
                #cur['in_path_node_status'][cur.val] = 1 

            for edge in edge_list:
                extend_edge(edge)


            # Round2 Fork to Edge (or fork)
            def extend_fork(node):
                base = node
                fork_direction = base.val

                connected_min_well_cost = base.well_cost
                connected_pipe_cost = 0
                connected_min_well_node = base.val
                
                
                # ???
                for edge_direction,edge_pipe_cost in base.pipes.items():
                    node = node_dict[edge_direction]
                    cur_connected_min_well_cost = node.info['connected_min_well_cost'].get(
                                                edge_direction,float('inf'))
                    cur_connected_pipe_cost = node.info['connected_pipe_cost'].get(
                                                edge_direction,float('inf'))+edge_pipe_cost 
                    cur_connected_min_well_node =  node.info['connected_min_well_cost'].get(
                                                edge_direction,-1)

                    if (cur_connected_min_well_cost+cur_connected_pipe_cost) <(
                        connected_min_well_cost+connected_pipe_cost
                    ):
                        connected_min_well_cost = cur_connected_min_well_cost
                        connected_pipe_cost = cur_connected_pipe_cost
                        connected_min_well_node = cur_connected_min_well_node

                        
                # Update Fork
                base.info['connected_min_well_cost'][fork_direction] = connected_min_well_cost
                base.info['connected_pipe_cost'][fork_direction] = connected_pipe_cost
                base.info['connected_min_well_node'][fork_direction] = connected_min_well_node
                
                
                pp(8,base.val,connected_min_well_cost,connected_pipe_cost,well_status,pipe_status)
                
                
                for edge_direction, edge_pipe_cost in base.pipes.items():
                    print('edge_direction:', edge_direction)
                    node = node_dict[edge_direction]
                    cur_connected_min_well_cost = node.info['connected_min_well_cost'].get(
                                                edge_direction, float('inf'))
                    cur_connected_pipe_cost = node.info['connected_pipe_cost'].get(
                                                edge_direction, float('inf'))
                    cur_connected_min_well_node = node.info['connected_min_well_node'].get(
                                                edge_direction, float('inf'))
                    
                    print('cur_connected_min_well_cost:',cur_connected_min_well_cost)
                    print('cur_connected_pipe_cost:',cur_connected_pipe_cost)

                    if (cur_connected_min_well_cost+cur_connected_pipe_cost)>(
                        connected_min_well_cost+connected_pipe_cost+edge_pipe_cost
                    ):
                        # well_status[connected_min_well_node] = 0
                        cur = node_dict[edge_direction]
                        extend_edge(cur, pre=base)

                        
            print('\nExtend Fork:\n')
                        

            for fork in fork_list:
                extend_fork(fork)
            # Recording Updated edge or fork
            
            
#             for edge in edge_list:
#                 extend_edge(edge)
            
#             for fork in fork_list:
#                 extend_fork(fork)


            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])


            def clean_off():
                well_edges = [x for x in pipe_status if (x[0] in well_status) or (x[1] in well_status)]
                print("well_edges:",well_edges)
                tmp = []
                for x in well_edges:
                    tmp+=x
                well_edges = tmp
                well_edges = Counter(well_edges)
                print("well_edges:",well_edges)
                well_edges = {k for k,v in well_edges.items() if (not k in well_status) and (v>1)}
                print("well_edges:", well_edges)
                
                for node in well_edges:
                    node = node_dict[node]
                    min_pipe,min_cost = sorted([(pipe,cost) for pipe,cost 
                                                in node.pipes.items()],key=lambda x:x[1])[0]
                    for pipe,cost in node.pipes.items():
                        if (cost>=min_cost) and (pipe!=min_pipe):
                            key = tuple(sorted([pipe,node.val]))
                            pipe_status[key]=0
                        
                
            
            # clean_off()
            
            
            pipe_status = {k:v for k,v in pipe_status.items() if v}
            well_status = {k:v for k,v in well_status.items() if v}
            
            
            if len(well_status)==n and len(pipe_status)>0:
                pipe_status = {}
            
            
            print()
            for x in node_dict.values():
                print('node:',x.val)
                print('info:',x.info)
                print()


            print("pipe_status:",pipe_status)
            print("well_status:",well_status)
            
            
            
            pipe_cost = sum([pipe_cost_dict[k] for k,v in pipe_status.items() if v])
            node_cost = sum([well_cost_dict[k] for k,v in well_status.items() if v])

            print('pipe_cost:',pipe_cost)
            print('node_cost:',node_cost)
            return pipe_cost, node_cost


        
        
        pipe_cost_dict = {tuple(sorted((x[0],x[1]))):x[2] for x in pipes }
        well_cost_dict = {(i+1):x for i,x in enumerate(wells)}
        
        # Base
        checked_set = set()
        node_dict = packupNode(n, wells, pipes)
        node_list = list(node_dict.values())

        pipe_cost, node_cost = getGroupCost(node_list, node_dict)
        group_cost = pipe_cost + node_cost
            
        return group_cost
        
# t = lambda x: print(x)
# # n = 3
# # wells = [1,2,2]
# # pipes = [[1,2,1],[2,3,1]]
        
# n = 3
# wells = [4600, 86000, 37000]
# pipes = [[1,2,45000],[2,3,22000]]
    
    
# tmp = Solution()
# tmp.minCostToSupplyWater(n, wells, pipes)

In [None]:
# Failed v10

import inspect

def print_variables(*args):
    frame = inspect.currentframe()
    frame = inspect.getouterframes(frame)[1]
    string = inspect.getframeinfo(frame[0]).code_context[0].strip()
    args_base = string[string.find('(') + 1:-1].split(',')
   
    names = []
    for i in args_base:
        if i.find('=') != -1:
            names.append(i.split('=')[1].strip())
       
        else:
            names.append(i)

    variable_dict = {name:arg for name,arg in zip(names,args)}
    
    ind = names[0]
    print(f'\ntata{ind} start')
    for k,v in variable_dict.items():
        if k==ind:
            continue
        print(f'{k}: {v}')

    print(f'tata{ind} end\n')
    
pp = print_variables




from collections import deque,Counter

class Node:
    def __init__(self, val):
        self.val = val
        self.well_cost = None
        self.pipes = dict()
        self.info = {
            'connected_min_well_node':{},
            'connected_min_well_cost':{},
            'connected_pipe_cost':{},
            'connected_pipe_set':{}
        }
        


class Solution:
    
    
    def minCostToSupplyWater(self, n, wells, pipes):
        

        def packupNode(n, well_cost, pipes):
            # Initial Node
            node_dict = {i:Node(i) for i in range(1,n+1)}

            for i in range(len(well_cost)):
                node_dict[i+1].well_cost = well_cost[i]

            # Add Edges  
            for edge in pipes:
                node_a = node_dict[edge[0]]
                node_b = node_dict[edge[1]]
                cost = edge[2]
                
                if node_a.val in node_b.pipes:
                    if  node_b.pipes[node_a.val]>cost:
                        node_b.pipes[node_a.val] = cost
                        node_a.pipes[node_b.val] = cost   
                
                else:
                    node_b.pipes[node_a.val] = cost
                    node_a.pipes[node_b.val] = cost 

            print('tata10',[(k,v.pipes) for k,v in node_dict.items()])
                    
            return node_dict


        def checkAround(node, node_dict):
            stack = deque()
            group_set = set()
            stack.append(node)

            while(stack):
                cur = stack.pop()
                group_set.add(cur)
                for pipe, pipe_cost in cur.pipes.items():
                    adj_node = node_dict[pipe]
                    if not adj_node in group_set:
                        if pipe_cost <= cur.well_cost or pipe_cost <= adj_node.well_cost:
                            stack.append(adj_node)    
            return group_set


        def getGroupCost(group_set, node_dict):

            """
                Edge Stack
                Fork Stack
            """
            well_status = {}
            pipe_status = {}

            def get_fork(group_set):
                return [x for x in group_set if len(x.pipes)>2] 

            def get_edge(group_set):
                return [x for x in group_set if len(x.pipes)<=1]

            fork_list = get_fork(group_set)
            edge_list = get_edge(group_set)

            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])



            # Round1 Edge to Fork (or edge)

            def update_pipe_status(node1,node2,status):
                key = tuple(sorted([node1.val, node2.val]))
                pipe_status[key] = status

                
            def update_value(pre,cur,direction):           
                if not pre:
                    cur.info['connected_min_well_cost'][direction] = cur.well_cost
                    cur.info['connected_pipe_cost'][direction] = 0
                    cur.info['connected_min_well_node'][direction] = cur.val
                    well_status[cur.val] = 1
                    return

                pipe_cost = cur.pipes[pre.val]
                well_cost = cur.well_cost
                
                print('tata3 cur:',cur.val)

                pre_connected_min_well_cost = pre.info['connected_min_well_cost'].get(direction, float('inf'))
                pre_connected_pipe_cost = pre.info['connected_pipe_cost'].get(direction, float('inf'))
                pre_connected_min_well_node = pre.info['connected_min_well_node'].get(direction, float('inf'))
        #         pre_connected_pipe_set = pre.info['connected_pipe_set'][direction]
                cur_connected_min_well_node = cur.info['connected_min_well_node'].get(direction, -1)


                if well_cost<(pre_connected_min_well_cost + pre_connected_pipe_cost):
                    cur_connected_min_well_cost = well_cost
                    cur_connected_pipe_cost = 0
        #             cur_connected_pipe_cost = set()
                    # Disconnect Pipe
                    update_pipe_status(pre, cur, 0)
                    cur_connected_min_well_node = cur.val
                    well_status[cur_connected_min_well_node] = 1
                    print('tata well_status:', well_status)
                    print('tata pipe_status:', pipe_status)

                else:
                    cur_connected_min_well_cost = pre_connected_min_well_cost
                    cur_connected_pipe_cost = pre_connected_pipe_cost + pipe_cost
                    # Connect pipe
                    update_pipe_status(pre, cur, 1)
                    if cur_connected_min_well_node==cur.val:
                        well_status[cur_connected_min_well_node] = 0
                    cur_connected_min_well_node = pre_connected_min_well_node
                    print('tata2:',cur_connected_min_well_node)
                    
                cur.info['connected_min_well_cost'][direction] = cur_connected_min_well_cost
                cur.info['connected_pipe_cost'][direction] = cur_connected_pipe_cost
                cur.info['connected_min_well_node'][direction] = cur_connected_min_well_node
                
                
                print('tata5 cur.info:',cur.info)
                
                
                counter_direction=None
                if len(cur.info['connected_min_well_cost'])==2:
                    counter_direction = [x for x in cur.info['connected_min_well_cost'] 
                                                    if x!=direction][0]
                    
                    
                if counter_direction:
                    
                    pp(9,direction,counter_direction,cur.val,cur.info['connected_min_well_cost'][direction] ,cur.info['connected_min_well_cost'][counter_direction],cur.info['connected_pipe_cost'][direction],cur.info['connected_pipe_cost'][counter_direction],pipe_status,well_status)
                    
                    
                    
                    total_direction_cost = (cur.info['connected_min_well_cost'][direction] 
                        + cur.info['connected_pipe_cost'][direction]
                        + cur.info['connected_pipe_cost'][counter_direction])
                    total_counter_direction_cost = (cur.info['connected_min_well_cost'][counter_direction] 
                        + cur.info['connected_pipe_cost'][counter_direction]
                        #+ cur.info['connected_pipe_cost'][direction]
                        )
                    
                    
                    pp(9.1,total_direction_cost,total_counter_direction_cost)
                    
                    if total_direction_cost>total_counter_direction_cost:
#                         min_node = cur.info['connected_min_well_node'][direction]
#                         pp(10,min_node)
                        
                        # well_status[min_node] = 0
                        # if cur.info['connected_pipe_cost'][direction]:
                        
                        if pipe_cost> pre.info['connected_min_well_node'][direction]:
                            update_pipe_status(pre, cur, 0)
                        else:
                            update_pipe_status(pre, cur, 1)
                            min_node = pre.info['connected_min_well_node'][direction]
                            well_status[min_node] = 0
                            
                            
                        print('cur.info:',cur.info)
                        print('direction:',direction)
                        print('lala:',cur.val, pre.val)
                        print('pipe_status:',pipe_status)
                        return True
                    
                return False
                        

            def extend_edge(node, pre=None):
                # Source node as direction
                if pre:
                    direction = pre.val
                else:
                    direction = node.val

                cur = node
                while cur and (not cur in fork_list):
                    print('extend_edge cur:',cur.val if cur else None)
                    print('pre:',pre.val if pre else None)
                    # Update middle node
                    stop_ind = update_value(pre,cur,direction)
                    if stop_ind:
                        return
                    tmp = cur
                    cur = [node_dict[x] for x in cur.pipes if (not pre) or (x!=pre.val)]
                    cur = cur[0] if cur else None
                    pre = tmp


                # Update fork
                if cur:
        #             pass
                    print('tata4 cur', cur.val)
                    print('tata4 pre', pre.val)
                    update_value(pre,cur,direction)
                    #???????????
                #cur['in_path_node_status'][cur.val] = 1 

            for edge in edge_list:
                extend_edge(edge)


            # Round2 Fork to Edge (or fork)
            def extend_fork(node):
                base = node
                fork_direction = base.val

                connected_min_well_cost = base.well_cost
                connected_pipe_cost = 0
                connected_min_well_node = base.val
                
                
                # ???
                for edge_direction,edge_pipe_cost in base.pipes.items():
                    node = node_dict[edge_direction]
                    cur_connected_min_well_cost = node.info['connected_min_well_cost'].get(
                                                edge_direction,float('inf'))
                    cur_connected_pipe_cost = node.info['connected_pipe_cost'].get(
                                                edge_direction,float('inf'))+edge_pipe_cost 
                    cur_connected_min_well_node =  node.info['connected_min_well_cost'].get(
                                                edge_direction,-1)

                    if (cur_connected_min_well_cost+cur_connected_pipe_cost) <(
                        connected_min_well_cost+connected_pipe_cost
                    ):
                        connected_min_well_cost = cur_connected_min_well_cost
                        connected_pipe_cost = cur_connected_pipe_cost
                        connected_min_well_node = cur_connected_min_well_node

                        
                # Update Fork
                base.info['connected_min_well_cost'][fork_direction] = connected_min_well_cost
                base.info['connected_pipe_cost'][fork_direction] = connected_pipe_cost
                base.info['connected_min_well_node'][fork_direction] = connected_min_well_node
                
                
                pp(8,base.val,connected_min_well_cost,connected_pipe_cost,well_status,pipe_status)
                
                
                for edge_direction, edge_pipe_cost in base.pipes.items():
                    node = node_dict[edge_direction]
                    cur_connected_min_well_cost = node.info['connected_min_well_cost'].get(
                                                edge_direction, float('inf'))
                    cur_connected_pipe_cost = node.info['connected_pipe_cost'].get(
                                                edge_direction, float('inf'))
                    cur_connected_min_well_node = node.info['connected_min_well_node'].get(
                                                edge_direction,-1)
                    
                    pp(7,node.val,cur_connected_min_well_cost,cur_connected_pipe_cost,connected_min_well_cost,connected_pipe_cost,edge_pipe_cost)
                    
                    
                    if(cur_connected_min_well_cost+cur_connected_pipe_cost+edge_pipe_cost)>(
                       connected_min_well_cost+connected_pipe_cost+edge_pipe_cost
                    ):
                        if edge_pipe_cost<(cur_connected_min_well_cost+cur_connected_pipe_cost):
                            cur = node_dict[edge_direction]
                            update_pipe_status(base, cur, 1)
                            well_status[cur_connected_min_well_node]=0
                    
                    
                    if (cur_connected_min_well_cost+cur_connected_pipe_cost)>(
                        connected_min_well_cost+connected_pipe_cost+edge_pipe_cost
                    ):
                        
                        # well_status[connected_min_well_node] = 0
                        cur = node_dict[edge_direction]
                        pp(8,cur.val)
                        extend_edge(cur, pre=base)

                        
            print('\nExtend Fork:\n')
                        

            for fork in fork_list:
                extend_fork(fork)
            # Recording Updated edge or fork
            
            
#             for edge in edge_list:
#                 extend_edge(edge)
            
#             for fork in fork_list:
#                 extend_fork(fork)


            print('Fork_List:', [x.val for x in fork_list])
            print("Edge_List:", [x.val for x in edge_list])


            def clean_off():
                well_edges = [x for x in pipe_status if (x[0] in well_status) or (x[1] in well_status)]
                print("well_edges:",well_edges)
                tmp = []
                for x in well_edges:
                    tmp+=x
                well_edges = tmp
                well_edges = Counter(well_edges)
                print("well_edges:",well_edges)
                well_edges = {k for k,v in well_edges.items() if (not k in well_status) and (v>1)}
                print("well_edges:", well_edges)
                
                for node in well_edges:
                    node = node_dict[node]
                    min_pipe,min_cost = sorted([(pipe,cost) for pipe,cost 
                                                in node.pipes.items()],key=lambda x:x[1])[0]
                    for pipe,cost in node.pipes.items():
                        if (cost>=min_cost) and (pipe!=min_pipe):
                            key = tuple(sorted([pipe,node.val]))
                            pipe_status[key]=0
                        
                
            
            # clean_off()
            
            
            pipe_status = {k:v for k,v in pipe_status.items() if v}
            well_status = {k:v for k,v in well_status.items() if v}
            
            
            if len(well_status)==n and len(pipe_status)>0:
                pipe_status = {}
            
            
            print()
            for x in node_dict.values():
                print('node:',x.val)
                print('info:',x.info)
                print()


            print("pipe_status:",pipe_status)
            print("well_status:",well_status)
            
            
            
            pipe_cost = sum([pipe_cost_dict[k] for k,v in pipe_status.items() if v])
            node_cost = sum([well_cost_dict[k] for k,v in well_status.items() if v])

            print('pipe_cost:',pipe_cost)
            print('node_cost:',node_cost)
            return pipe_cost, node_cost


        
        
        pipe_cost_dict = {}
        for x in pipes:
            key = tuple(sorted((x[0],x[1])))
            val = x[2]
            
            if key in pipe_cost_dict:
                if pipe_cost_dict[key]>val:
                    pipe_cost_dict[key]=val
            else:
                pipe_cost_dict[key]=val

        well_cost_dict = {(i+1):x for i,x in enumerate(wells)}
        
        # Base
        checked_set = set()
        node_dict = packupNode(n, wells, pipes)
        node_list = list(node_dict.values())

        pipe_cost, node_cost = getGroupCost(node_list, node_dict)
        group_cost = pipe_cost + node_cost
            
        return group_cost
        
# t = lambda x: print(x)
# # n = 3
# # wells = [1,2,2]
# # pipes = [[1,2,1],[2,3,1]]
        
# n = 3
# wells = [4600, 86000, 37000]
# pipes = [[1,2,45000],[2,3,22000]]
    
    
# tmp = Solution()
# tmp.minCostToSupplyWater(n, wells, pipes)