# Find the coefficients for an unbalanced chemical reaction

In [1]:
from scipy.linalg import solve
import string
import numpy as np

In [35]:
def get_element_counts(compound_str):
    """Given a compound_str like:
         "H2O"
       Return a dict with the count of each component element, e.g.:
         {"H":2,"O":1}
       Case matters!
    """
    if '(' in compound_str or ')' in compound_str:
        raise NotImplementedError("parentheses in a compound are not yet supported")  # TODO
    # break compound into substrings for each element, e.g. "H2O" -> ["H2", "O"]:
    elements = []
    element_start = 0
    for i, character in enumerate(compound_str):
        if i != 0 and character.isupper():
            elements.append(compound_str[element_start:i])
            element_start = i
    elements.append(compound_str[element_start:])
    # count up the number of atoms of each element:
    counts = {}
    for element in elements:
        e = element.rstrip(string.digits)
        count = element[len(e):]
        count = int(count) if len(count)>0 else 1
        if e in counts:  # e.g. if compound was "HC2H2O2", elements will contain H twice
            counts[e] += count
        else:
            counts[e] = count
    return counts

In [36]:
def balance_rxn(reaction_str):
    """Given a reaction_str like:
         "H2 + O2 -> H2O"
       Print the balanced reaction, e.g.:
         "2 H2 + 1 O2 -> 2 H2O"
       Case matters!
    """
    # Split up the reaction into individual terms:
    reactants, products = reaction_str.split("->")
    reactants = reactants.split("+")
    products  = products.split("+")
    reactants = [r.strip() for r in reactants]  # e.g. ['H2', 'O2']
    products  = [p.strip() for p in products]  # e.g. ['H2O']
    # Figure out what elements are involved in the reaction:
    elements = []
    for p in products:
        elements += list(get_element_counts(p).keys())
    for r in reactants:
        elements += list(get_element_counts(r).keys())
    elements = list(set(elements))
    # Set up the equation (matrix) to solve (Ax=b):
    num_coeffs = len(reactants) + len(products)
    num_elements = len(elements)  # this is also num_equations
    if num_elements < num_coeffs-1:
        raise ValueError("Unable to solve the given equation.")
    # if num_elements > num_coeffs:
    #     raise NotImplementedError  # remove linearly dependent row(s)
    A = np.zeros((num_elements,num_coeffs))
    b = np.zeros((num_elements,1))
    for eltidx, elt in enumerate(elements):
        for ridx, r in enumerate(reactants):
            A[eltidx,ridx] = get_element_counts(r).get(elt,0)
        for pidx, p in enumerate(products):
            A[eltidx,-len(products)+pidx] = -get_element_counts(p).get(elt,0)
    if num_elements < num_coeffs:
        # free variable, add a row to matrix with an arbitrary 
        # value (1) for one of the unknown coefficients
        A = np.concatenate([A,np.zeros((1,num_coeffs))])
        A[-1,0] = 1
        b = np.append(b,[1])
    # Solve it:
    coeffs = solve(a=A, b=b)
    if min(coeffs) < 1:
        # TODO: properly make all coeffs integers
        coeffs /= min(coeffs)
    solved_reactants = ["%0.2f %s" % (coeffs[ri],                r) for ri,r in enumerate(reactants)]  # TODO int not float
    solved_products  = ["%0.2f %s" % (coeffs[pi+len(reactants)], p) for pi,p in enumerate(products)]   # TODO int not float
    # Display the answer:
    print(" + ".join(solved_reactants) + " -> " + " + ".join(solved_products))

In [44]:
# TODO: speed up balance_rxn(), probably make it return a nice structure, maybe split everything up more too
# TODO?: make this a package, including tests

# Test it

In [38]:
balance_rxn("H2 + O2 -> H2O")

2.00 H2 + 1.00 O2 -> 2.00 H2O


In [39]:
balance_rxn("C4H10 + O2 -> CO2 + H2O")

1.00 C4H10 + 6.50 O2 -> 4.00 CO2 + 5.00 H2O


In [45]:
#balance_rxn("Cu + HNO3 -> Cu(NO3)2 + NO + H2O")
balance_rxn("Cu + HNO3 -> CuNO3NO3 + NO + H2O")

1.50 Cu + 4.00 HNO3 -> 1.50 CuNO3NO3 + 1.00 NO + 2.00 H2O
