In [26]:
from typing import Optional

# beef up the TreeNode class to add __str__ and __repr__ methods
# makes it way easier to debug and view the tree
class TreeNode:
    """Binary-tree node with a compact text form for debugging.

    ``str(node)`` renders as ``val(left)(right)``, with missing children shown
    as ``None``; e.g. ``1(2(None)(None))(3(None)(None))``. Rendering recurses
    into both children, so printing an extremely deep tree can itself hit the
    interpreter's recursion limit.
    """

    def __init__(self, val=0, left=None, right=None):
        self.val, self.left, self.right = val, left, right

    def __str__(self):
        return f'{self.val}({self.left})({self.right})'

    def __repr__(self):
        # Reuse the str() form so nodes inside lists/dicts and bare REPL
        # expressions display the same readable text.
        return self.__str__()
    
class Solution:
    """Solver for "Binary Tree Maximum Path Sum" (LeetCode 124)."""

    def maxPathSum(self, root: "Optional[TreeNode]") -> int:
        """Return the largest sum of node values along any path in the tree.

        A path is a chain of nodes linked by parent/child edges, each node used
        at most once. It need not pass through the root.

        Two quantities matter for every node:
          * gain(node): the best sum of a path that starts at ``node`` and
            descends through at most one child. This is the only thing a parent
            can extend, so it is what flows upward.
          * the best "bent" path whose highest node is ``node``
            (left branch -> node -> right branch). An ancestor cannot extend
            it, so it only competes for the overall answer.
        A child branch with negative gain is never worth taking, hence the
        clamps to 0.

        The traversal is iterative (an explicit stack instead of recursion) so
        degenerate, list-shaped trees deeper than Python's recursion limit
        (1000 frames by default) do not raise ``RecursionError``.

        Args:
            root: root of the tree, or ``None`` for an empty tree.

        Returns:
            The maximum path sum. An empty tree yields 0, as before.
        """
        if root is None:
            return 0

        # Pre-order listing: each node is appended before any of its
        # descendants, so walking the list backwards visits children first.
        order = []
        stack = [root]
        while stack:
            node = stack.pop()
            order.append(node)
            if node.left is not None:
                stack.append(node.left)
            if node.right is not None:
                stack.append(node.right)

        # id(node) -> gain(node). A child's entry is popped once its parent
        # has consumed it, keeping the dict small. Keying by id() avoids
        # depending on TreeNode's __eq__/__hash__.
        gains = {}
        best = float("-inf")
        for node in reversed(order):
            left = max(gains.pop(id(node.left)), 0) if node.left is not None else 0
            right = max(gains.pop(id(node.right)), 0) if node.right is not None else 0
            # Bent path peaking here: left branch + node + right branch.
            best = max(best, node.val + left + right)
            # Straight path a parent may extend: node plus its better branch.
            gains[id(node)] = node.val + max(left, right)
        return best

In [27]:
# Each case pairs a tree with its expected maximum path sum. Looping over a
# table replaces five copy-pasted stanzas, and the assert makes a wrong answer
# fail the cell instead of only printing a mismatch.
cases = [
    # test 1
    (TreeNode(1, TreeNode(2), TreeNode(3)), 6),
    # test 2: best path 15 -> 20 -> 7 avoids the negative root
    (TreeNode(-10, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7))), 42),
    # test 3: best path 15 -> 20 -> 7 -> 20
    (TreeNode(-10, TreeNode(9), TreeNode(20, TreeNode(15, TreeNode(-5), TreeNode(-10)), TreeNode(7, None, TreeNode(20)))), 62),
    # test 4: best "path" is the single node 10
    (TreeNode(-1, TreeNode(-2, TreeNode(-6), None), TreeNode(10, TreeNode(-3), TreeNode(-3))), 10),
    # test 5: best path 9 -> 3
    (TreeNode(-1, None, TreeNode(9, TreeNode(-6), TreeNode(3, None, TreeNode(-2)))), 12),
]

for case_no, (root, expected) in enumerate(cases, start=1):
    result = Solution().maxPathSum(root)
    print(f"test {case_no}:", root)
    print("result:", result, "expected:", expected)
    print()
    assert result == expected, f"test {case_no}: got {result}, expected {expected}"

1(2(None)(None))(3(None)(None))
Node: 2 l_val: 2 r_val: 2 s_val: 2 v_val: 2 max_v: 2
Node: 3 l_val: 3 r_val: 3 s_val: 3 v_val: 3 max_v: 3
Node: 1 l_val: 3 r_val: 4 s_val: 1 v_val: 6 max_v: 6
result: 6 expected: 6

-10(9(None)(None))(20(15(None)(None))(7(None)(None)))
Node: 9 l_val: 9 r_val: 9 s_val: 9 v_val: 9 max_v: 9
Node: 15 l_val: 15 r_val: 15 s_val: 15 v_val: 15 max_v: 15
Node: 7 l_val: 7 r_val: 7 s_val: 7 v_val: 7 max_v: 15
Node: 20 l_val: 35 r_val: 27 s_val: 20 v_val: 42 max_v: 42
Node: -10 l_val: -1 r_val: 25 s_val: -10 v_val: 34 max_v: 42
result: 42 expected: 42

-10(9(None)(None))(20(15(-5(None)(None))(-10(None)(None)))(7(None)(20(None)(None))))
Node: 9 l_val: 9 r_val: 9 s_val: 9 v_val: 9 max_v: 9
Node: -5 l_val: -5 r_val: -5 s_val: -5 v_val: -5 max_v: 9
Node: -10 l_val: -10 r_val: -10 s_val: -10 v_val: -10 max_v: 9
Node: 15 l_val: 10 r_val: 5 s_val: 15 v_val: 15 max_v: 15
Node: 20 l_val: 20 r_val: 20 s_val: 20 v_val: 20 max_v: 20
Node: 7 l_val: 7 r_val: 27 s_val: 7 v_val: 27