# In [2]:
from typing import Tuple, List, Iterator
import itertools
import numpy as np

# In [35]:
import numpy as np
from typing import Tuple, Iterator
import itertools

class MatmulProblem:
    """Split an (M, K, N) matmul iteration space into a P x P x P grid of tiles.

    Each dimension (M, K, N) is cut into exactly ``P`` contiguous tiles whose
    sizes are given by the corresponding row of ``tile_sizes``. Tiles are
    numbered with the M index varying fastest, then K, then N::

        tile_id = m_idx + k_idx * P + n_idx * P**2

    so valid ids are ``0 .. P**3 - 1``.
    """

    def __init__(self,
                 total_shape: Tuple[int, int, int],
                 tile_sizes: np.ndarray):
        """Validate the tiling and precompute tile boundaries.

        Args:
            total_shape: ``(M, K, N)``, the full extent of each dimension.
            tile_sizes: Array-like of shape ``(3, P)``. Row 0 = M sizes,
                row 1 = K sizes, row 2 = N sizes. Converted with ``dtype=int``.

        Raises:
            ValueError: if ``tile_sizes`` is not 2-D with 3 rows, if any tile
                size is negative, or if a row does not sum to the matching
                entry of ``total_shape``.
        """
        self.tile_sizes = np.asarray(tile_sizes, dtype=int)

        # Shape must be (3, P): one row of split sizes per dimension.
        if self.tile_sizes.ndim != 2 or self.tile_sizes.shape[0] != 3:
            raise ValueError(f"Expected array of shape (3, P), got {self.tile_sizes.shape}")

        self.M, self.K, self.N = total_shape
        self.P = self.tile_sizes.shape[1]
        self.N_tiles = self.P ** 3

        # A negative size makes the cumulative cuts non-monotonic, producing
        # empty/overlapping slices even when each row sums correctly.
        if (self.tile_sizes < 0).any():
            raise ValueError(
                f"Tile sizes must be non-negative, got {self.tile_sizes.tolist()}"
            )

        sums = self.tile_sizes.sum(axis=1)
        if not np.array_equal(sums, total_shape):
            raise ValueError(f"Splits sum to {sums}, expected {total_shape}")

        # cuts[d] holds P + 1 boundaries for dimension d; tile i of that
        # dimension spans the half-open range [cuts[d, i], cuts[d, i + 1]).
        self.cuts = np.hstack([
            np.zeros((3, 1), dtype=int),
            np.cumsum(self.tile_sizes, axis=1),
        ])

    def _check_tile_id(self, tile_id):
        """Raise IndexError unless ``0 <= tile_id < self.N_tiles``."""
        # Explicit raise instead of assert so the check survives `python -O`.
        if not 0 <= tile_id < self.N_tiles:
            raise IndexError(f"tile_id {tile_id} out of range [0, {self.N_tiles})")

    def get_coords(self, tile_id):
        """Return ``(m_idx, k_idx, n_idx)`` for ``tile_id`` (inverse of get_tile_id).

        Raises:
            IndexError: if ``tile_id`` is outside ``[0, N_tiles)``.
        """
        self._check_tile_id(tile_id)
        rest, m_idx = divmod(tile_id, self.P)
        n_idx, k_idx = divmod(rest, self.P)
        return (m_idx, k_idx, n_idx)

    def get_slice(self, tile_id):
        """Return ``(m_slice, k_slice, n_slice)`` covering tile ``tile_id``.

        Each slice spans the tile's half-open range along the M, K and N
        dimensions respectively, with plain ``int`` bounds and ``step=None``.

        Raises:
            IndexError: if ``tile_id`` is outside ``[0, N_tiles)``.
        """
        coords = self.get_coords(tile_id)  # also validates tile_id
        return tuple(
            slice(int(self.cuts[dim, idx]), int(self.cuts[dim, idx + 1]), None)
            for dim, idx in enumerate(coords)
        )

    def get_tile_id(self, coords):
        """Return the tile id for ``coords = (m_idx, k_idx, n_idx)``.

        Raises:
            ValueError: if ``coords`` does not have exactly 3 entries.
            IndexError: if any index is outside ``[0, P)``.
        """
        if len(coords) != 3:
            raise ValueError(f"Expected 3 coordinates (m, k, n), got {coords}")
        m_idx, k_idx, n_idx = coords
        if not all(0 <= c < self.P for c in coords):
            raise IndexError(f"coords {tuple(coords)} out of range [0, {self.P})")
        return m_idx + k_idx * self.P + n_idx * self.P ** 2