# Experiments and Implementations in Complexity Theory

## Showing that a path in a directed graph can be found in polynomial time

In [29]:
from typing import List

def get_nodes(edges: List[List[str]]):
    """Return the sorted list of distinct node names appearing in ``edges``.

    Each edge is a ``[source, target]`` pair and both endpoints are collected.
    A set does the de-duplication, so the cost is dominated by the final sort
    (O(E log E)) instead of an O(E * V) scan of a list for every endpoint.
    """
    nodes = set()
    for edge in edges:
        nodes.add(edge[0])
        nodes.add(edge[1])
    return sorted(nodes)


def checker(edges: List[List[str]], start: str, end: str) -> bool:
    """Return True if ``end`` is reachable from ``start`` along directed edges.

    This is a depth-first search over an adjacency map. Each node is marked at
    most once and each edge is followed at most once, so it runs in O(V + E).
    That is the polynomial-time PATH decider this section demonstrates.

    The result does not depend on the order in which ``edges`` are listed.
    A single pass over the list would miss a -> c in ``[["b", "c"], ["a", "b"]]``.
    A node is always considered reachable from itself.
    """
    adjacency = {}
    for edge in edges:
        adjacency.setdefault(edge[0], []).append(edge[1])

    marked = {start}
    stack = [start]
    while stack:
        node = stack.pop()
        if node == end:
            return True
        for neighbour in adjacency.get(node, ()):
            if neighbour not in marked:
                marked.add(neighbour)
                stack.append(neighbour)
    return False


## Check that a path exists from a to e (there is one: a -> b -> e)

edges = [["a", "b"], ["a", "c"], ["b", "e"], ["e", "c"], ["d", "f"]]
print(checker(edges, "a", "e"))

True


## Proof of Cook Levin Theorem: SAT is NP-complete

The goal of this proof is to show that any arbitrary NP computation can be converted into an instance of the Boolean satisfiability problem (SAT).

### Note

The Turing Machine (TM) implemented here is a Deterministic Turing Machine (DTM) instead of a Non-Deterministic Turing Machine (NDTM). We do this because a DTM is more intuitive to understand.

However, this doesn't undermine the construction. A DTM is the special case of an NDTM that has exactly one choice at every step, so the DTM's computation is a single branch (path) of an NDTM computation. That branch is exactly what the tableau encodes.

### Steps

- The `DTM` takes an input, runs on it and records its computation: a 2-dimensional array whose rows are the configurations of the TM at every step. This array is called a `tableau`.
- Using the `Cook_Levin_Prover`, we convert each symbol in the array into a boolean variable, set to true, of the form $$x_{i,j,s}$$
where
    - $i$ is the row
    - $j$ is the column
    - $s$ is the value in that cell.

- Then, these four checks are done on these variables and ANDed together:
    1. We check that every entry (i, j) in the computation has exactly one value. Therefore, we check that each entry has one or more values and that each entry cannot have 2 values.
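    2. We check that the first row is the start configuration: the start state sits in column `limit` of the first row (immediately after the left padding of blanks), `and` every cell of that row has a value.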
    3. We check that an `ACCEPT_STATE` is in the tableau.
    4. We check that each row yields the next one.
    
### Resources

- [Cook Levin Theorem Proof](https://www.youtube.com/watch?v=LW_37i96htQ)
- [Cook Levin](https://www.cs.ubc.ca/~condon/cpsc506/handouts/Cook-Levin.pdf)

In [32]:
from typing import List, Dict, Optional, Tuple, NewType
from enum import Enum, auto
from functools import reduce

# Symbol stored in every cell that has not been written; also pads the input.
BLANK = "BLANK"
# Name of the accepting state: a transition into it ends the run with STATE.ACCEPT.
ACCEPT_STATE = "ACCEPT_STATE"


class DIRECTION(Enum):
    """Direction the tape head moves after a transition."""

    LEFT = 1
    RIGHT = 2


class STATE(Enum):
    """Overall status of a DTM run."""

    ACCEPT = 1
    REJECT = 2
    RUNNING = 3


# (next_state, symbol_to_write or None to leave the cell unchanged, head move)
Transition = NewType("Transition", Tuple[str, Optional[str], DIRECTION])
# (current_state, symbol_under_head) -> Transition
Transitions = NewType("Transitions", Dict[Tuple[str, str], Transition])

def pretty_print(computations):
    """Print every tableau row on its own line, labelled with its step number."""
    step = 0
    for config in computations:
        print(f"Computation {step}: {config}")
        step += 1
        
class DTM:
    """A deterministic single-tape Turing machine that records its tableau.

    A configuration (one tableau row) is a flat list of cells. It holds the
    tape contents with the current state's name inserted immediately to the
    left of the cell under the head. ``computations`` collects one such row
    per step, starting with the initial configuration.
    """

    # Instance state, assigned in __init__. These are deliberately not class
    # attributes, so no mutable default is shared between machines.
    state: STATE
    head: List  # [current_state_name, head_index_on_tape]
    config: List[str]
    transitions: Transitions
    computations: List[List[str]]

    def __init__(self, limit: int, start_state: str, start_config: List[str], transitions: Transitions):
        """
        Args:
            limit: number of BLANK cells added on each side of the input.
            start_state: name of the initial state.
            start_config: input symbols. The caller's list is not modified.
            transitions: maps (state, symbol) to
                (new_state, symbol_to_write_or_None, DIRECTION).
        """
        padding = [BLANK for _ in range(limit)]
        # Build a new list rather than inserting into the caller's list.
        config = padding + [start_state] + list(start_config) + padding
        self.state = STATE.RUNNING
        self.config = config
        self.head = [start_state, limit]
        self.transitions = transitions
        self.computations = [config.copy()]

    @staticmethod
    def move(instruction: Tuple[str, Optional[str], DIRECTION], config: List[str], head: List) -> Tuple[List[str], List]:
        """Apply one instruction to copies of the tape and the head.

        ``instruction`` is (new_state, new_value, direction). When ``new_value``
        is None the cell under the head is left unchanged. With an unknown
        direction the head (state and index) is left unchanged.

        Returns:
            (new_tape, new_head). The arguments are not mutated.
        """
        config = config.copy()
        head = head.copy()
        new_state, new_value, direction = instruction
        if new_value is not None:
            config[head[1]] = new_value
        if direction == DIRECTION.LEFT:
            head[0] = new_state
            head[1] -= 1
        elif direction == DIRECTION.RIGHT:
            head[0] = new_state
            head[1] += 1
        return config, head

    @staticmethod
    def transition(transitions: Transitions, head: List, config: List[str], computations: List[List[str]]):
        """Perform one step of the machine.

        ``config`` is the current row, including the state marker. It is
        modified in place: the marker is removed. If a rule applies, the new
        row (with the marker re-inserted) is appended to ``computations``.

        Returns:
            (state, head, config, computations), where ``state`` is a STATE.

        Raises:
            IndexError: if the head is not on the tape (``limit`` too small).
        """
        current_state, index = head[0], head[1]
        # After removing the marker the tape has len(config) - 1 cells. Check
        # explicitly, because a negative index would otherwise wrap around.
        if not 0 <= index < len(config) - 1:
            raise IndexError(f"Head moved off the tape (index {index}); increase `limit`")
        # The marker always sits at the head index (it is inserted there below
        # and in __init__). Delete it by position so a tape symbol that shares
        # the state's name can never be removed by mistake.
        del config[index]
        symbol = config[index]

        move = transitions.get((current_state, symbol))
        if move is None:
            # No rule for (state, symbol): the machine halts and rejects.
            return STATE.REJECT, head, config, computations

        new_state = move[0]
        config, head = DTM.move((new_state, move[1], move[2]), config, head)
        # Re-insert the marker just left of the cell now under the head.
        config.insert(head[1], new_state)
        computations.append(config.copy())
        status = STATE.ACCEPT if new_state == ACCEPT_STATE else STATE.RUNNING
        return status, head, config, computations

    def run(self):
        """Step the machine until it accepts or rejects.

        Stores the final state, head and row on the instance. This loops
        forever if the machine never halts.
        """
        state, head, config, computations = self.transition(self.transitions, self.head, self.config, self.computations)
        while state == STATE.RUNNING:
            state, head, config, computations = self.transition(self.transitions, head, config, computations)
        self.state = state
        self.head = head
        self.config = config
            
# Sipser's example machine M2, which decides { 0^(2^n) | n >= 0 }. It blanks
# the first 0, then repeatedly crosses off every other 0, and accepts once a
# single 0 is left. States q0..q4 correspond to Sipser's q1..q5. Any
# (state, symbol) pair without an entry rejects.
start_state = "q0"
start_config = ["0", "0", "0", "0"]
transitions = {
    ("q0", "0"): ("q1", BLANK, DIRECTION.RIGHT),
    ("q1", "x"): ("q1", None, DIRECTION.RIGHT),
    ("q1", "0"): ("q2", "x", DIRECTION.RIGHT),
    ("q2", "x"): ("q2", None, DIRECTION.RIGHT),
    ("q2", "0"): ("q3", None, DIRECTION.RIGHT),
    ("q3", "0"): ("q2", "x", DIRECTION.RIGHT),
    # q3 must also skip cells that are already crossed off. Without this rule
    # the machine rejects e.g. 0^8 on its second sweep.
    ("q3", "x"): ("q3", None, DIRECTION.RIGHT),
    ("q2", BLANK): ("q4", None, DIRECTION.LEFT),
    ("q4", BLANK): ("q1", None, DIRECTION.RIGHT),
    ("q4", "x"): ("q4", None, DIRECTION.LEFT),
    ("q4", "0"): ("q4", None, DIRECTION.LEFT),
    ("q1", BLANK): (ACCEPT_STATE, None, DIRECTION.RIGHT),
}
limit = 2
tm = DTM(limit, start_state, start_config, transitions)
tm.run()
# Uncomment to print the tableau:
# pretty_print(tm.computations)

### Computation TO SAT

def bool_or(x: bool, y: bool):
    """Logical OR. Bitwise ``|`` on two bools yields a bool."""
    result = x | y
    return result


def neg_bool_or(x: bool, y: bool):
    """Return ``(not x) OR (not y)``: false only when both ``x`` and ``y`` hold.

    This is the "not both" clause used to forbid two values in one cell.
    """
    not_x = not x
    not_y = not y
    return not_x or not_y


def bool_and(x: bool, y: bool):
    """Logical AND. Bitwise ``&`` on two bools yields a bool."""
    result = x & y
    return result

class Cook_Levin_Prover:
    """Encode a DTM tableau as boolean variables and check the Cook-Levin conditions.

    Every cell (i, j) holding symbol s becomes a variable named
    ``x_{i}_{j}_{s}`` whose value is True. ``run`` returns the AND of the four
    conditions ``condition_1`` .. ``condition_4``.
    """

    def __init__(self, limit: int, computations: List[List[str]], symbols: List[str], states: List[str], start_state: str = "q0"):
        """
        Args:
            limit: BLANK padding used by the DTM. The start state is expected
                in column ``limit`` of the first row.
            computations: the tableau, one configuration per row.
            symbols: every value a cell may hold (tape symbols and state names).
            states: the state names among ``symbols``.
            start_state: name of the initial state (default "q0").
        """
        self.limit = limit
        self.computations = computations
        self.symbols = symbols
        self.states = states
        self.start_state = start_state

    @staticmethod
    def computations_to_boolean_vars(computations: List[List[str]]) -> Tuple[List[List[str]], Dict[str, bool]]:
        """Convert the tableau into variable names and their (all True) values.

        Returns:
            (boolean_vars, boolean_values). ``boolean_vars[i][j]`` is
            ``x_{i}_{j}_{symbol}`` and ``boolean_values`` maps each name to True.
        """
        boolean_vars = []
        boolean_values = {}
        for i, computation in enumerate(computations):
            row = [f"x_{i}_{j}_{symbol}" for j, symbol in enumerate(computation)]
            boolean_vars.append(row)
            boolean_values.update(dict.fromkeys(row, True))
        return boolean_vars, boolean_values

    @staticmethod
    def _symbol_of(var: str) -> str:
        """Return the symbol part of a variable name ``x_{i}_{j}_{symbol}``."""
        # i and j are integers (no underscores). Splitting at most three times
        # keeps symbols that contain "_" themselves (e.g. ACCEPT_STATE) intact.
        return var.split("_", 3)[3]

    @staticmethod
    def condition_1(symbols: List[str], boolean_vars: List[List[str]], boolean_values: Dict[str, bool]) -> bool:
        """
        We check that every entry (i, j) in the computation
        has exactly one value. Therefore, we check that each
        entry has one or more values and that each entry cannot
        have 2 values.

        Returns True for an empty tableau.
        """
        # A repeated symbol names the same variable, so compare each one once.
        unique_symbols = list(dict.fromkeys(symbols))
        for i, row in enumerate(boolean_vars):
            for j in range(len(row)):
                present = [boolean_values.get(f"x_{i}_{j}_{s}") is True for s in unique_symbols]
                # "At least one": OR over all symbols for this cell.
                if not any(present):
                    return False
                # "At most one": for every pair of distinct symbols, NOT(both).
                for a in range(len(present)):
                    for b in range(a + 1, len(present)):
                        if present[a] and present[b]:
                            return False
        return True

    @staticmethod
    def condition_2(limit: int, boolean_vars: List[List[str]], boolean_values: Dict[str, bool], start_state: str = "q0") -> bool:
        """
        We check that the first row is the start configuration: every cell of
        the first row is set, and the start state sits in column ``limit``.

        Returns False for an empty tableau or an out-of-range ``limit``.
        """
        if not boolean_vars:
            return False
        first_row = boolean_vars[0]
        if not 0 <= limit < len(first_row):
            return False
        all_set = all(boolean_values.get(var) is True for var in first_row)
        return all_set and first_row[limit] == f"x_0_{limit}_{start_state}"

    @staticmethod
    def condition_3(boolean_vars: List[List[str]], boolean_values: Dict[str, bool]) -> bool:
        """
        We check that an `ACCEPT_STATE` is in the tableau.

        This ORs together every cell variable whose symbol is exactly
        ``ACCEPT_STATE``. False is returned if there is none.
        """
        return any(
            boolean_values.get(var) is True
            for row in boolean_vars
            for var in row
            if Cook_Levin_Prover._symbol_of(var) == ACCEPT_STATE
        )

    @staticmethod
    def condition_4(computations: List[List[str]], states: List[str]) -> bool:
        """
        We check that each row can follow the previous row.

        A DTM step moves the state marker one cell and may rewrite the cell
        under the head. So, around the marker in the current row, only the
        3-cell window (left neighbour, marker, cell under the head) may
        change. Everything left and right of it must be identical, and the
        window itself must differ.

        This is a structural check only. It does not verify that the change
        matches a rule, because the prover is not given the transition table.

        A tableau with fewer than two rows trivially passes. A row without
        any state marker fails.
        """

        def find_state_index(row: List[str]) -> Optional[int]:
            for index, symbol in enumerate(row):
                if symbol in states:
                    return index
            return None

        for current_row, next_row in zip(computations, computations[1:]):
            state_index = find_state_index(current_row)
            if state_index is None:
                return False
            # Clamp at 0 so a marker in column 0 does not produce a negative
            # slice bound, which Python would interpret from the end.
            before = max(state_index - 1, 0)
            after = state_index + 2  # exclusive end of the 3-cell window
            if current_row[:before] != next_row[:before]:
                return False
            if current_row[after:] != next_row[after:]:
                return False
            if current_row[before:after] == next_row[before:after]:
                return False
        return True

    def run(self):
        """Return True when all four Cook-Levin conditions hold for the tableau."""
        boolean_vars, boolean_values = self.computations_to_boolean_vars(self.computations)
        conditions = [
            self.condition_1(self.symbols, boolean_vars, boolean_values),
            self.condition_2(self.limit, boolean_vars, boolean_values, self.start_state),
            self.condition_3(boolean_vars, boolean_values),
            self.condition_4(self.computations, self.states),
        ]
        return all(conditions)


# State names q0..q4 plus the accepting state; every cell holds one of these
# or one of the tape symbols "0", "x", BLANK.
states = [f"q{k}" for k in range(5)] + [ACCEPT_STATE]
symbols = states[:-1] + ["0", "x", BLANK, ACCEPT_STATE]
cook_levin_prover = Cook_Levin_Prover(limit, tm.computations, symbols, states)
cook_levin_prover.run()

True

## Freivalds Algorithm

In [35]:
"""
- The product of two matrices is equivalent to composing the two linear transformations

- FORMULA FOR THE PRODUCT OF TWO n x n MATRICES

- The order matters

 M2 = [[a, b],[c, d]]
 
 M1 = [[e, f],[g, h]]
 
 M2 * [e, g] = e * [a, c] + g * [b, d] = [(a * e) + (b * g), (c * e) + (d * g)]
 
 M2 * [f, h] = f * [a, c] + h * [b, d] = [(a * f) + (b * h), (c * f) + (d * h)]
 
 M2 * M1 = [[((a * e) + (b * g)), ((a * f) + (b * h))],[((c * e) + (d * g)), ((c * f) + (d * h))]]
"""

from typing import List, Tuple
import numpy as np

# for i in range(3):
#     for j in range(3):
#         print(j, i)
# This shows how to index by (column, row) rather than (row, column)

def scalar_mul(a: int, b: List[int]):
    """Return a new list holding every entry of vector ``b`` multiplied by ``a``."""
    return [a * component for component in b]

def combine(a: List[int], b: List[int]):
    """Add two vectors element-wise.

    An empty vector acts as the identity. If exactly one argument is empty,
    the other is returned as-is (the same list object, not a copy).
    Otherwise the lengths must match.
    """
    a_empty = len(a) == 0
    b_empty = len(b) == 0
    if a_empty and not b_empty:
        return b
    if b_empty and not a_empty:
        return a

    assert len(a) == len(b)

    return [left + right for left, right in zip(a, b)]

def linear_combination(a: List[int], b: List[List[int]]):
    """Return ``sum_i a[i] * b[i]``, where each ``b[i]`` is a vector.

    ``b[i]`` is taken as the i-th row of ``b``. For a matrix ``M`` this
    therefore computes the row-vector product ``a^T M``, not ``M a``.
    """
    assert len(a) == len(b)
    total = []
    for coefficient, vector in zip(a, b):
        total = combine(total, scalar_mul(coefficient, vector))
    return total
        
    
def matrix_product(A, B):
    """Return the matrix product ``A · B`` of two list-of-lists matrices.

    ``A`` is r x m and ``B`` is m x c, and the result is a new r x c list of
    lists. Square n x n matrices are the common case, but any compatible
    shapes work.

    Raises:
        ValueError: if a row of ``A`` does not have ``len(B)`` entries, or if
            ``B`` is ragged.
    """
    if not A:
        return []
    inner = len(B)
    if any(len(row) != inner for row in A):
        raise ValueError("matrix_product: each row of A must have len(B) entries")
    if inner == 0:
        return [[] for _ in A]
    cols = len(B[0])
    if any(len(row) != cols for row in B):
        raise ValueError("matrix_product: B must be rectangular")
    # Column k of B as a tuple, so each entry is a dot product of row and column.
    columns = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, column)) for column in columns] for row in A]

############################
### Freivald's Algorithm ###
############################

"""
First, choose a random r ∈ F_p and let x = (1, r, r^2, ..., r^(n-1)). Then compute y = Cx and z = A·(Bx), outputting YES if y = z and NO otherwise.
"""

import random

p = 79

r = random.randint(5, p - 1)

A = [[0, 2], [1, 0]]

B = [[1, -2], [1, 0]]

x = [r ** i for i in range(len(A))]

C = matrix_product(A, B)

print(x)

# `linear_combination(v, M)` computes sum_i v[i] * M[i]. It combines the ROWS
# of M and yields the row vector v^T M. A matrix-vector product M v combines
# the COLUMNS of M instead, so each matrix is transposed first. Combining
# rows directly would compare x^T B A with x^T A B, which differ in general.
A_columns = [list(column) for column in zip(*A)]
B_columns = [list(column) for column in zip(*B)]
C_columns = [list(column) for column in zip(*C)]

Bx = linear_combination(x, B_columns)  # B x

print(Bx)

z = linear_combination(Bx, A_columns)  # A (B x)

print(z)

y = linear_combination(x, C_columns)  # C x

print(y)

# The arithmetic is over the integers rather than mod p. If C != A·B, the
# check can only pass when r is a root of a nonzero polynomial of degree < n.
assert z == y, "Freivalds check failed: C != A·B"

[-2, 100]
[51, -98]


## The Sumcheck Protocol

In [16]:
from typing import List, Tuple, Union
from functools import reduce
from random import randint
from math import log2

class FFE:
    """An element of the prime field F_p: an integer modulo ``modulus``.

    Arithmetic returns new objects. Equality and hashing are defined on
    (element, modulus), so elements behave as values.
    """

    element: int

    modulus: int

    def __init__(self, element: int, modulus: int):
        # Normalise into [0, modulus) so equal field elements compare equal.
        self.element = element % modulus
        self.modulus = modulus

    def _check_modulus(self, rhs: 'FFE') -> None:
        """Raise if ``rhs`` belongs to a different field."""
        if self.modulus != rhs.modulus:
            raise Exception("Modulus Mismatch")

    def __add__(self, rhs: 'FFE'):
        self._check_modulus(rhs)
        return FFE(self.element + rhs.element, self.modulus)

    def __sub__(self, rhs: 'FFE'):
        self._check_modulus(rhs)
        return FFE(self.element - rhs.element, self.modulus)

    def __mul__(self, rhs: 'FFE'):
        self._check_modulus(rhs)
        return FFE(self.element * rhs.element, self.modulus)

    def __eq__(self, rhs: object):
        if not isinstance(rhs, FFE):
            # Let Python fall back to its default comparison instead of
            # raising AttributeError on e.g. ``ffe == 0``.
            return NotImplemented
        return (self.element == rhs.element) and (self.modulus == rhs.modulus)

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; keep the two consistent.
        return hash((self.element, self.modulus))

    def __repr__(self):
        return f"FFE(element: {self.element}, modulus: {self.modulus})"

    __str__ = __repr__
    
    
class FF:
    """Factory for the elements of the prime field F_modulus."""

    modulus: int

    def __init__(self, modulus: int):
        self.modulus = modulus

    def new(self, element: int) -> FFE:
        """Wrap ``element`` (reduced modulo ``modulus``) as a field element."""
        return FFE(element, self.modulus)

    def zero(self) -> FFE:
        """Additive identity."""
        return self.new(0)

    def one(self) -> FFE:
        """Multiplicative identity."""
        return self.new(1)

    def ffe_to_binary(self, n: FFE) -> List[FFE]:
        """Return the binary digits of ``n.element`` as field elements.

        The most significant bit comes first, and 0 yields ``[zero]``.
        """
        digits = format(n.element, "b")
        return [self.one() if digit == "1" else self.zero() for digit in digits]

class Linear_Polynomial:
    """A univariate linear polynomial, stored by its values at 0 and 1.

    ``evaluations == [p(0), p(1)]``. Evaluation uses the line through those two
    points: p(x) = p(0) - p(0) * x + p(1) * x.
    """

    evaluations: List[FFE]

    # Compared against Multilinear_Polynomial.num_of_vars by callers.
    num_of_vars = 1

    def __init__(self, evaluations: List[FFE]):
        if len(evaluations) == 2:
            self.evaluations = evaluations
            return
        raise Exception("Invalid evaluations")

    def evaluate(self, x: FFE) -> FFE:
        """Return p(x)."""
        at_zero, at_one = self.evaluations[0], self.evaluations[1]
        return (at_zero - (at_zero * x)) + (at_one * x)

    def __repr__(self):
        return f"Linear Polynomial({self.evaluations})"

    __str__ = __repr__

    def __add__(self, rhs: 'Linear_Polynomial'):
        """Pointwise sum of two linear polynomials."""
        if self.num_of_vars != rhs.num_of_vars:
            raise Exception("Invalid Operation")
        summed = [left + right for left, right in zip(self.evaluations, rhs.evaluations)]
        return Linear_Polynomial(summed)

    
class Multilinear_Polynomial:
    """A multilinear polynomial given by its values on the hypercube {0,1}^n.

    ``evaluations[k]`` is the value at the point whose coordinates are the bits
    of ``k``, read most-significant first. Variable 0 is therefore the highest
    bit (see ``get_pairing_index``). Annotations that name ``FF``, ``FFE`` and
    ``Linear_Polynomial`` are strings, so the class does not depend on the
    definition order of those names.
    """

    evaluations: List["FFE"]

    num_of_vars: int

    num_of_evaluation_points: int

    ff: "FF"

    def __init__(self, ff: "FF", evaluations: List["FFE"], num_of_vars: int):
        """
        Args:
            ff: the field. ``ff.zero()`` pads missing evaluations.
            evaluations: values on the hypercube in index order. Shorter lists
                are padded with zeros; longer ones are truncated.
            num_of_vars: number of variables, at least 2. One-variable
                polynomials are represented by ``Linear_Polynomial``.

        Raises:
            Exception: if ``num_of_vars < 2``.
        """
        self.ff = ff
        if num_of_vars < 2:
            raise Exception("Invalid number of variables")
        num_of_evaluation_points = 2 ** num_of_vars
        self.num_of_evaluation_points = num_of_evaluation_points
        self.num_of_vars = num_of_vars
        missing = num_of_evaluation_points - len(evaluations)
        if missing > 0:
            evaluations = evaluations + [ff.zero() for _ in range(missing)]
        elif missing < 0:
            evaluations = evaluations[:num_of_evaluation_points]
        self.evaluations = evaluations

    def __str__(self):
        return f"Multilinear_Polynomial({self.evaluations})"

    def __repr__(self):
        return f"Multilinear_Polynomial({self.evaluations})"

    def __add__(self, rhs: 'Multilinear_Polynomial'):
        """Pointwise sum of two polynomials in the same variables."""
        if (self.ff.modulus != rhs.ff.modulus) or (self.num_of_vars != rhs.num_of_vars):
            raise Exception("Invalid Operation")
        new_evals = [(x + y) for (x, y) in zip(self.evaluations, rhs.evaluations)]
        return Multilinear_Polynomial(self.ff, new_evals, self.num_of_vars)

    def __mul__(self, rhs: 'Multilinear_Polynomial'):
        """Tensor product ``h(x, y) = f(x) * g(y)`` over disjoint variables.

        The result has ``self.num_of_vars + rhs.num_of_vars`` variables, with
        ``self``'s variables first (most significant). This is not the product
        ``f(x) * g(x)`` in shared variables, which is not multilinear in general.

        Raises:
            Exception: if the two polynomials are over different fields.
        """
        if self.ff.modulus != rhs.ff.modulus:
            raise Exception("Invalid Operation")
        new_evals = [x * y for x in self.evaluations for y in rhs.evaluations]
        return Multilinear_Polynomial(self.ff, new_evals, self.num_of_vars + rhs.num_of_vars)

    @staticmethod
    def get_pairing_index(var_index: int, num_of_vars: int):
        """Pair evaluation indices that differ only in variable ``var_index``.

        Returns ``[(k0, k1), ...]`` in increasing order of ``k0``. In each pair,
        ``k0`` has the variable set to 0 and ``k1 = k0 + offset`` has it set to 1.
        This runs in O(2^num_of_vars).
        """
        if var_index < 0 or var_index > num_of_vars - 1:
            raise Exception("Invalid variable index")
        num_of_evals = 2 ** num_of_vars
        offset = num_of_evals // (2 ** (var_index + 1))
        return [
            (block + k, block + k + offset)
            for block in range(0, num_of_evals, 2 * offset)
            for k in range(offset)
        ]

    def partial_evaluate(self, var_index: int, var: "FFE") -> Union["Linear_Polynomial", "Multilinear_Polynomial"]:
        """Fix variable ``var_index`` to ``var``; the result has one variable fewer."""
        new_points = []
        for low, high in self.get_pairing_index(var_index, self.num_of_vars):
            y_1 = self.evaluations[low]
            y_2 = self.evaluations[high]
            # Line through (0, y_1) and (1, y_2), evaluated at ``var``.
            new_points.append(((y_2 - y_1) * var) + y_1)

        if len(new_points) == 2:
            return Linear_Polynomial(new_points)
        return Multilinear_Polynomial(self.ff, new_points, self.num_of_vars - 1)

    def evaluate(self, variables: Tuple["FFE", ...]):
        """Evaluate at a full point; ``variables[i]`` is the value of variable i."""
        if len(variables) != self.num_of_vars:
            raise Exception("Wrong number of variables")
        # Fix the current first variable repeatedly until only a linear
        # polynomial in the last variable remains.
        poly = self
        for value in variables[:-1]:
            poly = poly.partial_evaluate(0, value)
        return poly.evaluate(variables[-1])



class SumCheck_Protocol:
    """Helpers for the sumcheck protocol."""

    @staticmethod
    def sum(poly: Multilinear_Polynomial):
        """Return the sum of ``poly`` over the boolean hypercube.

        This is the sum of all its stored evaluations.
        """
        def add(left, right):
            return left + right

        return reduce(add, poly.evaluations)
    
    
"""
THE SUMCHECK PROTOCOL

GOAL: A prover wants to prove to the verifier the sum of a polynomial over the boolean hypercube

SETUP:
    
    Polynomial: f(x, y, z) = x + y + z + 7

"""

ff = FF(17)
evaluations = [ff.new(7), ff.new(8), ff.new(8), ff.new(9), ff.new(8), ff.new(9), ff.new(9), ff.new(10)]
mult_1 = Multilinear_Polynomial(ff, evaluations, 3)

## STEP 1: The prover computes the sum of the polynomial over the boolean hypercube and sends it to the
## verifier

sum_check = SumCheck_Protocol()
h = sum_check.sum(mult_1)

## STEP 2: The prover computes a polynomial g_1(x) and sends it to the verifier
poly_1 = mult_1.partial_evaluate(2, ff.zero()) ## f(x, y, 0)
poly_2 = mult_1.partial_evaluate(2, ff.one()) ## f(x, y, 1)
poly_3 = poly_1 + poly_2 ## f(x, y) = f(x, y, 0) + f(x, y, 1)
poly_4 = poly_3.partial_evaluate(1, ff.zero()) ## f(x, 0)
poly_5 = poly_3.partial_evaluate(1, ff.one()) ## f(x, 1)
poly_6 = poly_4 + poly_5 ## f(x) = f(x, 0) + f(x, 1)
g_1_of_x = poly_6

## STEP 3: The verifier stores g_1(x) as s_1(x)
s_1_of_x = g_1_of_x

## STEP 4: The verifier checks that h == s_1(0) + s_1(1)
x = s_1_of_x.evaluate(ff.zero())
y = s_1_of_x.evaluate(ff.one())
assert h == x + y, "Invalid proof"

## STEP 5: The verifier draws a uniformly random field element `r1` and sends it to the prover
r1 = ff.new(randint(0, ff.modulus - 1))

## STEP 6: The prover computes g_2(y) = f(r1, y, 0) + f(r1, y, 1) and sends it to the verifier
poly_1 = mult_1.partial_evaluate(2, ff.zero()) ## f(x, y, 0)
poly_2 = mult_1.partial_evaluate(2, ff.one()) ## f(x, y, 1)
poly_3 = poly_1 + poly_2 ## f(x, y) = f(x, y, 0) + f(x, y, 1)
poly_4 = poly_3.partial_evaluate(0, r1) ## f(r1, y)
g_2_of_y = poly_4

## STEP 7: The verifier stores g_2(y) as s_2(y)
s_2_of_y = g_2_of_y

## STEP 8: The verifier checks that s_1(r1) == s_2(0) + s_2(1)
x = s_1_of_x.evaluate(r1)
y = s_2_of_y.evaluate(ff.zero())
z = s_2_of_y.evaluate(ff.one())
assert x == y + z, "Invalid proof"

## STEP 9: The verifier draws a uniformly random field element `r2` and sends it to the prover
r2 = ff.new(randint(0, ff.modulus - 1))

## STEP 10: The prover computes g_3(z) = f(r1, r2, z) and sends it to the verifier
poly_1 = mult_1.partial_evaluate(0, r1) ## f(r1, y, z)
poly_2 = poly_1.partial_evaluate(0, r2) ## f(r1, r2, z)
g_3_of_z = poly_2

## STEP 11: The verifier stores g_3(z) as s_3(z)
s_3_of_z = g_3_of_z

## STEP 12: The verifier checks that s_2(r2) == s_3(0) + s_3(1)
x = s_2_of_y.evaluate(r2)
y = s_3_of_z.evaluate(ff.zero())
z = s_3_of_z.evaluate(ff.one())
assert x == y + z, "Invalid proof"

## STEP 13: The verifier draws a uniformly random field element `r3`
r3 = ff.new(randint(0, ff.modulus - 1))

## STEP 14: The verifier evaluates f(r1, r2, r3) itself, using its (oracle) access to f
evaluation = mult_1.evaluate((r1, r2, r3))

## LAST STEP: The verifier checks that s_3(r3) == f(r1, r2, r3)
x = s_3_of_z.evaluate(r3)
assert x == evaluation, "Invalid proof"

### Interactive Proof(IP) for Matrix Multiplication(`MatMult`) Using the Sumcheck Protocol

The best known algorithm for `MatMult` runs in $O(n^{2.37286})$, but it is not practical.

Check it out [here](https://arxiv.org/abs/2010.05846).

Above, we showed how to check that `A·B == C` using Freivalds' algorithm.

Here, we show how to turn that check into an IP using the Sumcheck Protocol.

That is, the prover wants to convince the verifier that the matrix `C` it claims is indeed the product of `A` and `B`.



In [21]:
# Matrix dimension. It must be a power of two: a matrix is flattened into
# 2^(2 * log_n) evaluations, so any other n would be silently truncated.
n = 16

if n <= 0 or n & (n - 1):
    raise ValueError(f"n must be a positive power of two, got {n}")

log_n = int(log2(n))

ff = FF(17)


def generate_n_by_n_matrix(n):
    """Return an n x n list of lists of random elements of the module-level field ``ff``.

    Entries are drawn row by row.
    """
    return [[ff.new(randint(0, 9999999)) for _ in range(n)] for _ in range(n)]


def matrix_product(ff: FF, A: List[List[FFE]], B: List[List[FFE]]):
    """Return the product of two square n x n matrices of field elements.

    Each entry is accumulated starting from ``ff.zero()``; ``n`` is ``len(A)``.
    """
    size = len(A)
    product = []
    for row in range(size):
        out_row = []
        for col in range(size):
            acc = ff.zero()
            for k in range(size):
                acc = acc + A[row][k] * B[k][col]
            out_row.append(acc)
        product.append(out_row)
    return product
    
# Two random n x n matrices and their true product over the field `ff`.
a, b = generate_n_by_n_matrix(n), generate_n_by_n_matrix(n)

c = matrix_product(ff, a, b)

def matrix_to_one_dim_list(x: List[List[FFE]]):
    """Flatten a matrix row by row (row-major order) into a new list."""
    return [entry for row in x for entry in row]

import itertools

# Each matrix is viewed as a function {0,1}^log_n x {0,1}^log_n -> F. The
# first log_n variables select the row and the last log_n select the column
# (row-major flattening, variable 0 = most significant bit).
num_of_vars = log_n + log_n

evaluations_a = matrix_to_one_dim_list(a)
multi_poly_a = Multilinear_Polynomial(ff, evaluations_a, num_of_vars)

evaluations_b = matrix_to_one_dim_list(b)
multi_poly_b = Multilinear_Polynomial(ff, evaluations_b, num_of_vars)

evaluations_c = matrix_to_one_dim_list(c)
multi_poly_c = Multilinear_Polynomial(ff, evaluations_c, num_of_vars)

# Random point (r_row, r_col), drawn uniformly from the whole field. Drawing
# from a tiny range would make it easy for a wrong C to pass the check.
r = [ff.new(randint(0, ff.modulus - 1)) for _ in range(num_of_vars)]

# C~(r_row, r_col)
c = multi_poly_c.evaluate(r)

possible_values = [ff.zero(), ff.one()]

# Every k in {0,1}^log_n, in the same order that nested loops would produce.
b_values = [list(bits) for bits in itertools.product(possible_values, repeat=log_n)]

# Identity being checked:  C~(r_row, r_col) == sum_k A~(r_row, k) * B~(k, r_col)
summation = ff.zero()

for values in b_values:
    input_for_a = r[:log_n] + values
    input_for_b = values + r[log_n:]
    a = multi_poly_a.evaluate(input_for_a)
    b = multi_poly_b.evaluate(input_for_b)
    summation = summation + (a * b)

assert c == summation

### Interactive Proof