<a href="https://colab.research.google.com/github/Victorlouisdg/simulators/blob/main/conjugate_gradient.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
# Fix the global NumPy RNG seed so the random test problem below is reproducible.
np.random.seed(seed=0)

In [2]:
def iterate_cg(A, x, r, p):
    """
    Perform a single conjugate gradient step for the system Ax = b.

    The residual convention is r = Ax - b, so a search direction of -r
    points downhill.

    Parameters:
    A: matrix of interest (CG assumes it is symmetric positive definite)
    x: current iterate
    r: current residual, r = Ax - b
    p: current search direction

    Returns:
    (x_next, r_next, p_next): updated iterate, residual and direction.
    The direction is always computed, even if the caller stops afterwards.
    """
    A_p = A @ p
    res_sq = np.dot(r, r)

    # Exact line-search step length along p.
    step = res_sq / np.dot(p, A_p)

    x_next = x + step * p
    r_next = r + step * A_p

    # beta = ||r_next||^2 / ||r||^2 keeps the new direction A-conjugate
    # to the previous ones.
    ratio = np.dot(r_next, r_next) / res_sq
    p_next = ratio * p - r_next

    return x_next, r_next, p_next


def run_conjugate_gradient(A, x0, b, max_iter=40, tol=0.0, verbose=True):
    """
    Solve Ax = b with the conjugate gradient algorithm.

    Parameters:
    A: matrix of interest (CG assumes it is symmetric positive definite)
    x0: initial point
    b: vector in linear system (Ax = b)
    max_iter: max number of iterations to run CG
    tol: stop early once the Euclidean norm of the residual is <= tol.
         The default 0.0 only stops on an exactly-zero residual, which
         would otherwise make the next step divide 0 by 0 and yield NaNs.
    verbose: if True, print per-iteration diagnostics: the L1 norm of the
         recursively updated residual, and the (signed) sum of the true
         residual Ax - b. The signed sum can cancel, so it is only a rough
         cross-check that the recursive residual has not drifted.

    Returns:
    The final iterate xk, an approximate solution of Ax = b.
    """
    # initial iteration
    xk = x0
    rk = A @ xk - b
    pk = -rk

    for _ in range(max_iter):
        # Converged: another step would compute 0/0 inside iterate_cg.
        if np.linalg.norm(rk) <= tol:
            break
        xk, rk, pk = iterate_cg(A, xk, rk, pk)
        if verbose:
            e = np.sum(np.abs(rk))
            print(e, "   ", np.sum(A @ xk - b))

    return xk

In [3]:
n = 50

# Diagonal test matrix whose diagonal entries are random draws in [0, 1).
A = np.diag(np.random.rand(n))

# Start from the all-ones vector; the right-hand side is random as well.
x0 = np.ones(n)
b = np.random.rand(n)

run_conjugate_gradient(A, x0, b)

7.715430637257627     -6.3349154853438545
11.162950873920975     -2.3418185592504557
9.358631407528712     1.63532454741279
7.029308633952032     0.9208721855641002
6.527974551648756     -0.0663242243141961
4.9760794716582115     0.15293991541694046
3.5280528860719507     -0.4058441770762788
2.319860607947089     -0.04547620038907216
2.0145632724361406     0.1479219538319713
1.182030340562115     0.11073720661931376
0.800327012805395     0.014843795825131595
0.39019226922741596     -0.027795685223465986
0.17899860747612845     -0.04229233424426922
0.15394265231436655     -0.006079697980735508
0.1309517469773319     0.024503309415939502
0.09281307797754822     0.0078531697427961
0.08068186002336536     -0.00887574966125194
0.10774140504415071     -0.016008698048967866
0.053629076283923896     0.0036865959597398665
0.038783108225437006     0.01050849843010967
0.02576450887996202     -0.0003568932176055183
0.009186039591557634     -0.00017422315001524354
0.004476445848841198     -0.000172