In [4]:
import random

class BinaryTree:
    """A binary tree node holding ``content`` and optional left/right children.

    ``depth`` holds the height of the subtree rooted at this node once a
    height routine has computed it; the value -1 means "not computed yet".
    """

    def __init__(self, content):
        self.content = content
        # Children start empty; callers attach subtrees by assigning directly.
        self.left, self.right = None, None
        # Sentinel: height of this subtree has not been computed yet.
        self.depth = -1

    def __str__(self):
        """Render as ``(content ( left | right))``, recursing into the children."""
        pieces = [
            "(", str(self.content),
            " ( ", str(self.left),
            " | ", str(self.right),
            "))",
        ]
        return "".join(pieces)

In [5]:
# O(n^2) naive algorithm
# depth is called repeatedly on the same nodes

def is_balanced_binary_tree(btree):
    """Return True if, at every node, the left and right subtree heights differ by at most 1.

    Naive version: for each node it asks ``depth`` for both subtree heights,
    then recurses into the left subtree and finally the right one.
    An empty tree (``None``) counts as balanced.
    """
    if btree is None:
        return True

    # Local check first, so an imbalance here stops before descending.
    if abs(depth(btree.left) - depth(btree.right)) > 1:
        return False
    if not is_balanced_binary_tree(btree.left):
        return False
    return is_balanced_binary_tree(btree.right)

def depth(btree):
    """Return the height of ``btree``: nodes on its longest root-to-leaf path.

    An empty tree (``None``) has height 0; a single node has height 1.

    The height is always recomputed from the children and then stored in
    ``btree.depth``. The stored value is deliberately not read back: nothing
    invalidates it when the tree is mutated, so trusting it could return a
    stale height and make ``is_balanced_binary_tree`` accept an unbalanced
    tree after the tree has been edited.
    """
    if btree is None:
        return 0
    btree.depth = 1 + max(depth(btree.left), depth(btree.right))
    return btree.depth

In [6]:
# efficient algorithm, get heights of subtrees and 
# check subtrees if balanced at the same time

def is_balanced_binary_tree2(btree):
    """Return True if ``btree`` is height-balanced, via a single bottom-up pass.

    Delegates to ``checkBalance``, which yields ``(balanced, height)`` for the
    whole tree; only the balanced flag is returned here.
    """
    balanced, _height = checkBalance(btree)
    return balanced

def checkBalance(btree):
    """Check balance and measure height of ``btree`` in one post-order traversal.

    Returns:
        ``(balanced, height)`` where ``balanced`` is True when, at every node,
        the left and right subtree heights differ by at most 1, and ``height``
        is the number of nodes on the longest root-to-leaf path (0 for an
        empty tree).

    Side effect: writes the computed height into ``btree.depth`` for every
    visited node.
    """
    # Base case: the empty tree is balanced and has height 0.
    # ``is None`` (not ``== None``) so a node type's __eq__ is never invoked.
    if btree is None:
        return True, 0

    # Both subtrees are always visited (no early exit), so every node's
    # ``depth`` attribute is refreshed even when an imbalance is found.
    left_balanced, left_depth = checkBalance(btree.left)
    right_balanced, right_depth = checkBalance(btree.right)

    balanced = (abs(left_depth - right_depth) <= 1
                and left_balanced and right_balanced)
    btree.depth = 1 + max(left_depth, right_depth)

    return balanced, btree.depth

In [7]:
# Test case 1: grow a left-leaning chain of 6 nodes -- one initial node plus
# 5 new roots, each adopting the previous root as its left child.

bt = BinaryTree(random.randint(0, 100))
print(bt)
for c1 in range(5):
    bt2 = BinaryTree(random.randint(0, 100))
    print(bt2)
    # push bt2 on top: the old root becomes its left child
    bt2.left, bt = bt, bt2

(69 ( None | None))
(14 ( None | None))
(51 ( None | None))
(82 ( None | None))
(9 ( None | None))
(21 ( None | None))


In [8]:
# Cross-check: the naive and the single-pass checkers must agree on the chain.
assert is_balanced_binary_tree(bt) == is_balanced_binary_tree2(bt)
is_balanced_binary_tree(bt)

False

In [10]:
# Test case 2: a perfect binary tree, balanced by construction.

def make_random_balanced_tree(depth):
    """Build a perfect binary tree with ``depth`` levels of random ints in [0, 100].

    Every level is completely filled, so the result is height-balanced.
    Returns None when ``depth`` is not positive.
    """
    if not depth > 0:
        return None
    root = BinaryTree(random.randint(0, 100))
    root.left = make_random_balanced_tree(depth - 1)
    root.right = make_random_balanced_tree(depth - 1)
    return root
    
# 3 levels -> 7 nodes, every level full
balanced_tree = make_random_balanced_tree(depth=3)
print(balanced_tree)

(50 ( (14 ( (79 ( None | None)) | (66 ( None | None)))) | (87 ( (97 ( None | None)) | (99 ( None | None))))))
