<h1 align=center> AlphaTensor </h1>

Loading factorizations found by AlphaTensor and recombination.

- Copyright 2022 DeepMind Technologies Limited
- All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0
- All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY).  You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode
- Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.
- This is not an official Google product.

## Algorithms

In [None]:
import numpy as np
from google.colab import files

### Loading Required Files

In [None]:
!git clone https://github.com/google-deepmind/alphatensor.git
%cd alphatensor/algorithms

Load one of the two files provided in this folder: `factorizations_r.npz` (algorithms in standard arithmetic) or `factorizations_f2.npz` (algorithms in arithmetic modulo 2). The interactive upload cell below is only needed if the files are not already present.

In [None]:
!ls

In [None]:
# Interactive upload
# uploaded = files.upload()
# filename = list(uploaded.keys())[0]
# with open(filename, 'rb') as f:
#   factorizations = dict(np.load(f, allow_pickle=True))

In [None]:
# @title Select a factorization file
FILE = "factorizations_r.npz" # @param ["factorizations_r.npz","factorizations_f2.npz"] {"allow-input":true}
# Open the file chosen in the form field above, not a hard-coded name.
with open(FILE, 'rb') as f:
  factorizations = dict(np.load(f, allow_pickle=True))

In [None]:
factorizations.keys()

In [None]:
# Print available factorizations and their ranks.
for key in factorizations:
  u, v, w = factorizations[key]
  rank = u.shape[-1]
  assert rank == v.shape[-1] and rank == w.shape[-1]
  print(f'{key}: rank={rank}')

In [None]:
factorizations["2,2,2"]

Note that, as provided, the factorizations decompose the *symmetrized* matrix multiplication tensor, which represents the bilinear operation $\mathbf{A}, \mathbf{B} \mapsto (\mathbf{A} \cdot \mathbf{B})^T$. This convention is standard in the literature, and a factorization can be converted between the symmetrized and non-symmetrized forms by reordering the rows of its third factor.
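
The conversion between the two conventions can be sketched as a row permutation of the third factor `w`. The helper names below (`symmetrized_tensor`, `standard_tensor`, `desymmetrize_permutation`) are illustrative and not part of the AlphaTensor repository, and the factors are random placeholders rather than a real factorization. The index layout is assumed to match `get_mamu_tensor_rectangular` in this notebook.

```python
import numpy as np

def symmetrized_tensor(a, b, c):
    """Tensor of (A, B) -> (A @ B)^T; third index is k * a + i."""
    t = np.zeros((a * b, b * c, c * a), dtype=np.int32)
    for i in range(a):
        for j in range(b):
            for k in range(c):
                t[i * b + j, j * c + k, k * a + i] = 1
    return t

def standard_tensor(a, b, c):
    """Tensor of (A, B) -> A @ B; third index is i * c + k."""
    t = np.zeros((a * b, b * c, a * c), dtype=np.int32)
    for i in range(a):
        for j in range(b):
            for k in range(c):
                t[i * b + j, j * c + k, i * c + k] = 1
    return t

def desymmetrize_permutation(a, c):
    """perm[i * c + k] = k * a + i, i.e. where C[i, k] sits in flattened C^T."""
    return np.arange(c * a).reshape(c, a).T.reshape(-1)

a, b, c = 3, 4, 5
perm = desymmetrize_permutation(a, c)

# Permuting the third axis turns the symmetrized tensor into the standard one.
assert np.array_equal(symmetrized_tensor(a, b, c)[:, :, perm],
                      standard_tensor(a, b, c))

# The same permutation applied to the rows of w converts a factorization.
# Random placeholder factors are used here, not an AlphaTensor result.
rng = np.random.default_rng(0)
u, v, w = (rng.integers(-1, 2, size=(s, 6)) for s in (a * b, b * c, c * a))
sym = np.einsum('ir,jr,kr->ijk', u, v, w)
std = np.einsum('ir,jr,kr->ijk', u, v, w[perm])
assert np.array_equal(std, sym[:, :, perm])
```

Only `w` changes because the first two tensor axes index `A` and `B` identically in both conventions.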

In [None]:
def get_mamu_tensor_rectangular(a: int, b: int, c: int) -> np.ndarray:
  """Returns the symmetrized matrix multiplication tensor T_{a, b, c}."""
  result = np.full((a*b, b*c, c*a), 0, dtype=np.int32)
  for i in range(a):
    for j in range(b):
      for k in range(c):
        result[i * b  + j][j * c + k][k * a + i] = 1
  return result


# Test correctness of a factorization.
tensor = get_mamu_tensor_rectangular(3, 4, 5)
u, v, w = factorizations['3,4,5']
reconstruction = np.einsum('ir,jr,kr->ijk', u, v, w)
if np.array_equal(tensor, reconstruction):
  print('Factorization is correct in R (standard arithmetic).')
elif np.array_equal(tensor, np.mod(reconstruction, 2)):
  print('Factorization is correct in F2 (modular arithmetic).')
else:
  print('Factorization is incorrect.')
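
The same check can be wrapped in a reusable function. The sketch below runs it on Strassen's classical rank-7 algorithm for 2x2 matrices, written out by hand in the symmetrized convention. The function name `check_factorization` is hypothetical, and the factors are not loaded from the `.npz` file, whose `"2,2,2"` entry may look different. Taking absolute values of the factors gives a variant that only holds modulo 2, which exercises the F2 branch.

```python
import numpy as np

def mamu_tensor(a, b, c):
    """Symmetrized matrix multiplication tensor, as in the cell above."""
    t = np.zeros((a * b, b * c, c * a), dtype=np.int32)
    for i in range(a):
        for j in range(b):
            for k in range(c):
                t[i * b + j, j * c + k, k * a + i] = 1
    return t

def check_factorization(u, v, w, a, b, c):
    """Returns 'R', 'F2', or None depending on where the factorization holds."""
    tensor = mamu_tensor(a, b, c)
    reconstruction = np.einsum('ir,jr,kr->ijk', u, v, w)
    if np.array_equal(tensor, reconstruction):
        return 'R'
    if np.array_equal(tensor, np.mod(reconstruction, 2)):
        return 'F2'
    return None

# Strassen's algorithm; columns correspond to the products M1..M7.
u = np.array([[1, 0, 1, 0, 1, -1, 0],    # A11
              [0, 0, 0, 0, 1, 0, 1],     # A12
              [0, 1, 0, 0, 0, 1, 0],     # A21
              [1, 1, 0, 1, 0, 0, -1]])   # A22
v = np.array([[1, 1, 0, -1, 0, 1, 0],    # B11
              [0, 0, 1, 0, 0, 1, 0],     # B12
              [0, 0, 0, 1, 0, 0, 1],     # B21
              [1, 0, -1, 0, 1, 0, 1]])   # B22
# Rows follow the flattened transpose of C: C11, C21, C12, C22.
w = np.array([[1, 0, 0, 1, -1, 0, 1],    # C11
              [0, 1, 0, 1, 0, 0, 0],     # C21
              [0, 0, 1, 0, 1, 0, 0],     # C12
              [1, -1, 1, 0, 0, 1, 0]])   # C22

print(check_factorization(u, v, w, 2, 2, 2))
print(check_factorization(np.abs(u), np.abs(v), np.abs(w), 2, 2, 2))
```

The first call reports that the factorization holds in standard arithmetic. The second drops the signs, so it is valid only modulo 2, where -1 and 1 coincide.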

<h2 style="text-align: center; color: blue"> Application </h2>

Here we apply a factorization found by AlphaTensor to multiply two concrete matrices. For the 2x2 case the rank is 7, so the product needs 7 scalar multiplications instead of the usual 8, which is where the speed-up comes from.

In [None]:
import numpy as np

# Factor data (U, V, W) stored under the "2,2,2" key
factor_data = factorizations["2,2,2"]

# Dimensions for 2x2 multiplication
m, k, n = 2, 2, 2
R = 7  # Rank for 2x2x2 is 7

# Extract U, V, W matrices (collections of factors)
# U_all_ranks has shape (m*k, R) -> (4, 7)
# V_all_ranks has shape (k*n, R) -> (4, 7)
# W_all_ranks has shape (m*n, R) -> (4, 7)
U_all_ranks = factor_data[0]
V_all_ranks = factor_data[1]
W_all_ranks = factor_data[2]

# Let's define two toy 2x2 matrices A and B
A = np.array([[1, 2],
              [3, 4]], dtype=np.int32)

B = np.array([[5, 6],
              [7, 8]], dtype=np.int32)

print("Matrix A:\n", A)
print("Matrix B:\n", B)

# --- AlphaTensor Multiplication ---

# 1. Calculate M_r values (R of them)
M_values = np.zeros(R, dtype=np.int32)
for r in range(R):
    u_r_flat = U_all_ranks[:, r]      # Shape (4,)
    u_r_matrix = u_r_flat.reshape(m, k) # Shape (2, 2)
    M_values[r] = np.sum(u_r_matrix * A)

# 2. Calculate N_r values (R of them)
N_values = np.zeros(R, dtype=np.int32)
for r in range(R):
    v_r_flat = V_all_ranks[:, r]      # Shape (4,)
    v_r_matrix = v_r_flat.reshape(k, n) # Shape (2, 2)
    N_values[r] = np.sum(v_r_matrix * B)

# 3. Calculate P_r = M_r * N_r (the R scalar multiplications)
P_values = M_values * N_values

# 4. Reconstruct C_alpha
C_alpha = np.zeros((m, n), dtype=np.int32)
for r in range(R):
    # Each column of W is a flattened (A @ B)^T, because the factorization is
    # for the symmetrized tensor. Reshaping in NumPy's default row-major (C)
    # order would therefore give the transpose of the intended C matrix:
    # w_r_matrix = W_all_ranks[:, r].reshape(m, n)
    # Reshaping in Fortran (column-major) order undoes that transpose.
    w_r_matrix = W_all_ranks[:, r].reshape((m, n), order='F')
    C_alpha += P_values[r] * w_r_matrix

print("\nResult C_alpha (using AlphaTensor factors):\n", C_alpha)

# --- Standard Matrix Multiplication for comparison ---
C_standard = A @ B
# Or using np.matmul(A, B)

print("\nResult C_standard (using A @ B):\n", C_standard)

# Check if the results are the same
if np.array_equal(C_alpha, C_standard):
    print("\nSuccess! AlphaTensor's method matches standard multiplication.")
else:
    print("\nError! Results do not match.")
    print("Difference:", C_alpha - C_standard)
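
Why the `W` columns need Fortran order can be seen on a small rectangular case. The sketch below uses a made-up 2x3 matrix `C` and assumes, as in the symmetrized convention above, that a `W` column is laid out like the flattened transpose of `C`. It does not use any AlphaTensor data.

```python
import numpy as np

m, n = 2, 3
# Hypothetical product matrix, chosen so every entry is recognizable.
C = np.array([[11, 12, 13],
              [21, 22, 23]])

# Layout used by a W column: entry l * m + i holds C[i, l].
flat_ct = C.T.reshape(-1)

via_fortran = flat_ct.reshape((m, n), order='F')   # recovers C
via_transpose = flat_ct.reshape(n, m).T            # equivalent formulation
via_c_order = flat_ct.reshape(m, n)                # wrong for this layout

assert np.array_equal(via_fortran, C)
assert np.array_equal(via_transpose, C)
assert not np.array_equal(via_c_order, C)
```

For square matrices the row-major reshape happens to produce exactly the transpose of `C`. For rectangular ones it scrambles the entries, so the Fortran-order reshape (or reshaping to `(n, m)` and transposing) is the safer habit.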

<h3 style="text-align: center"> General Fast Multiplication Algorithm </h3>

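The function below generalizes the 2x2 example: **a general way to multiply matrices of any shape for which a factorization is available**. It looks up the factors stored under a key `"m,k,n"` and applies them to an `m x k` matrix `A` and a `k x n` matrix `B`.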

In [None]:
# @title General Fast Multiplication Algorithm
def multiply_with_alphatensor_factors(A, B, key, factorizations):
    """
    Multiplies matrices A and B using AlphaTensor factors for the given key.

    Args:
        A (np.ndarray): The first input matrix.
        B (np.ndarray): The second input matrix.
        key (str): The key corresponding to the factorization (e.g., "m,k,n").
        factorizations (dict): A dictionary where keys are "m,k,n" strings
                               and values are the (3, size, R) factor tensors.

    Returns:
        np.ndarray: The resulting matrix C = A @ B.
        Or None if dimensions are mismatched or key is not found.
    """
    if key not in factorizations:
        print(f"Error: Factorization key '{key}' not found.")
        return None

    # 1. Parse m, k, n from the key
    try:
        m_str, k_str, n_str = key.split(',')
        m, k, n = int(m_str), int(k_str), int(n_str)
    except ValueError:
        print(f"Error: Key '{key}' is not in the format 'm,k,n'.")
        return None

    # 2. Validate input matrix dimensions
    if A.shape != (m, k):
        print(f"Error: Matrix A shape {A.shape} does not match key dimensions ({m},{k}).")
        return None
    if B.shape != (k, n):
        print(f"Error: Matrix B shape {B.shape} does not match key dimensions ({k},{n}).")
        return None

    # 3. Extract factor data and Rank (R)
    factor_data = factorizations[key] # This is the (3, size_flat, R) array
    # U_all_ranks has shape (m*k, R)
    # V_all_ranks has shape (k*n, R)
    # W_all_ranks has shape (m*n, R)
    U_all_ranks = factor_data[0]
    V_all_ranks = factor_data[1]
    W_all_ranks = factor_data[2]

    # The rank R is the number of columns in U_all_ranks (or V or W)
    if U_all_ranks.shape[0] != m * k:
        print(f"Error: U factor dimension mismatch. Expected {m*k}, got {U_all_ranks.shape[0]}")
        return None
    if V_all_ranks.shape[0] != k * n:
        print(f"Error: V factor dimension mismatch. Expected {k*n}, got {V_all_ranks.shape[0]}")
        return None
    if W_all_ranks.shape[0] != m * n:
        print(f"Error: W factor dimension mismatch. Expected {m*n}, got {W_all_ranks.shape[0]}")
        return None

    R = U_all_ranks.shape[1]
    if V_all_ranks.shape[1] != R or W_all_ranks.shape[1] != R:
        print(f"Error: Rank mismatch between U, V, W factors for key '{key}'.")
        return None

    print(f"Using AlphaTensor factors for {m}x{k} @ {k}x{n} multiplication with Rank R={R}")

    # --- AlphaTensor Multiplication Logic ---
    # 1. Calculate M_r values
    # M_values[r] = sum_{i,j} U_all_ranks[i*k + j, r] * A[i, j]
    # Column r of U is reshaped to (m, k) in row-major order (NumPy's default),
    # multiplied elementwise with A, and summed.
    M_values = np.zeros(R, dtype=A.dtype) # Match dtype of input
    for r_idx in range(R):
        u_r_matrix = U_all_ranks[:, r_idx].reshape(m, k)
        M_values[r_idx] = np.sum(u_r_matrix * A)


    # 2. Calculate N_r values
    N_values = np.zeros(R, dtype=B.dtype)
    for r_idx in range(R):
        v_r_matrix = V_all_ranks[:, r_idx].reshape(k, n)
        N_values[r_idx] = np.sum(v_r_matrix * B)

    # 3. Calculate P_r = M_r * N_r
    P_values = M_values * N_values

    # 4. Reconstruct C_alpha
    C_alpha = np.zeros((m, n), dtype=P_values.dtype)
    for r_idx in range(R):
        # W indexes (A @ B)^T, so reshape in Fortran order to obtain C directly.
        w_r_matrix = W_all_ranks[:, r_idx].reshape((m, n), order='F')
        C_alpha += P_values[r_idx] * w_r_matrix

    return C_alpha

In [None]:
A_2 = np.array([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
B_2 = np.array([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])

multiply_with_alphatensor_factors(A_2, B_2, "3,3,3", factorizations)
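
The per-rank loops in `multiply_with_alphatensor_factors` can also be written as three matrix-vector products. The following is a minimal sketch under the same layout assumptions (row-major `U` and `V`, `W` indexing the transposed result). `multiply_vectorized` and `naive_factors` are hypothetical names. `naive_factors` builds the trivial rank `m*k*n` factorization as stand-in data so the block runs without the `.npz` file.

```python
import numpy as np

def naive_factors(m, k, n):
    """Trivial rank m*k*n factorization of the symmetrized tensor (stand-in)."""
    rank = m * k * n
    u = np.zeros((m * k, rank), dtype=np.int64)
    v = np.zeros((k * n, rank), dtype=np.int64)
    w = np.zeros((n * m, rank), dtype=np.int64)
    r = 0
    for i in range(m):
        for j in range(k):
            for l in range(n):
                u[i * k + j, r] = 1   # picks A[i, j]
                v[j * n + l, r] = 1   # picks B[j, l]
                w[l * m + i, r] = 1   # adds the product to C[i, l]
                r += 1
    return u, v, w

def multiply_vectorized(A, B, u, v, w):
    """Computes A @ B from factors (u, v, w) without Python loops over the rank."""
    m, k = A.shape
    k2, n = B.shape
    if k != k2:
        raise ValueError(f"Inner dimensions differ: {A.shape} vs {B.shape}")
    if u.shape[0] != m * k or v.shape[0] != k * n or w.shape[0] != n * m:
        raise ValueError("Factor shapes do not match the input matrices")
    left = A.reshape(-1) @ u          # R linear combinations of entries of A
    right = B.reshape(-1) @ v         # R linear combinations of entries of B
    c_t_flat = w @ (left * right)     # flattened (A @ B)^T
    return c_t_flat.reshape(n, m).T

A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(3, 4)
u, v, w = naive_factors(2, 3, 4)
assert np.array_equal(multiply_vectorized(A, B, u, v, w), A @ B)
```

With factors loaded from the file, the call would look like `multiply_vectorized(A_2, B_2, *factorizations["3,3,3"])`, assuming each entry unpacks into `u, v, w` as in the cells above. This version raises `ValueError` instead of printing and returning `None`, which is a design choice rather than something the original function does.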

<h3 style="text-align: center; color: green"> Future Work </h3>

1. Compute the factorization dictionary key from the shapes of the input matrices.
2. How should this be implemented in a lower-level language (C, C++)?
3. How should this be implemented on a GPU?
   - CUDA C++
   - Parallelization
4. Can this be implemented on a TPU?
5. How should errors for incompatible matrix shapes be handled?
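
As a starting point for items 1 and 5, the key can be derived from the input shapes, with incompatible shapes rejected up front. This is only a sketch: `factorization_key` is a hypothetical helper, and the `"m,k,n"` key format is inferred from the keys used earlier in this notebook (for example `"2,2,2"` and `"3,4,5"`).

```python
import numpy as np

def factorization_key(A, B):
    """Builds the 'm,k,n' lookup key for multiplying A (m x k) by B (k x n)."""
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("Both inputs must be 2-D matrices")
    m, k = A.shape
    k2, n = B.shape
    if k != k2:
        raise ValueError(f"Cannot multiply shapes {A.shape} and {B.shape}")
    return f"{m},{k},{n}"

A = np.zeros((3, 4))
B = np.zeros((4, 5))
print(factorization_key(A, B))
```

A caller could then check `factorization_key(A, B) in factorizations` and fall back to `A @ B` when no factorization is stored for that shape.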

## Recombination

Recombination extends the base cases found by AlphaTensor to larger matrix sizes: it provides tools to construct new U, V, W factorizations for larger matrices by combining known factorizations for smaller ones.

In [None]:
%cd /content

In [None]:
!python3 -m alphatensor.recombination.example