This example demonstrates the API for registering custom operator implementations for specific input and output tensor formats. It also demonstrates the customization API for defining new sparse tensor formats and sparsifiers, and shows how to register custom operator and sparsifier implementations for them.

In [None]:
import torch
import sten
import scipy

In [None]:
from native_scripting import compile
import functools
import ctypes
import time
import math
from heapq import nlargest

In [None]:
# Use functools.cache when this Python provides it; otherwise fall back to
# an unbounded lru_cache, which memoizes the same way.
cache = getattr(functools, "cache", None)
if cache is None:
    cache = functools.lru_cache(maxsize=None)

In [None]:
@cache
def group_n_m2(dense_shape, dense_dtype, n, m, tileM):
    """Compile and return the native kernel that selects column groups and masks.

    The generated C++ function has the signature
    ``func(dense, sparse, masks, columns)`` and is specialized (all sizes are
    baked into the source text) for the given shape, dtype, ``n``, ``m`` and
    ``tileM``. The result is memoized by ``cache``, so each distinct argument
    tuple is compiled once.

    What the generated kernel does, per tile of ``tileM`` rows and per group of
    ``m`` consecutive columns:

    * enumerates every ordered combination of 4 columns inside the group;
    * for each row of the tile, takes the ``n`` largest of the 4 selected
      values and adds them to a running total;
    * keeps the combination whose total is the largest (it must exceed 0);
    * writes the winning 0/1 mask into ``masks``, multiplies ``sparse`` by that
      mask in place, and stores the 4 chosen column offsets in ``columns``.

    Leftover columns (``ncols % m``) are handled separately: fewer than 3
    leftover columns are kept entirely (mask set to 1), exactly 3 are searched
    with 3-column combinations, and more than 3 with 4-column combinations.

    All buffers are indexed as flat ``row * ncols + col``. Rows beyond
    ``(nrows // tileM) * tileM`` are not visited by the tile loops.

    Args:
        dense_shape: ``(nrows, ncols)`` of the matrix; only the first two
            entries are read.
        dense_dtype: ``torch.float32`` or ``torch.float64``.
        n: number of values kept per row inside each selected 4-column group.
        m: width of a column group.
        tileM: number of rows that share one column selection.

    Returns:
        The ``func`` symbol of the compiled library, with ``argtypes`` set to
        four ``ctypes.c_void_p`` pointers.

    Raises:
        AssertionError: if ``dense_dtype`` is not float32 or float64.
    """
    nrows = dense_shape[0]
    ncols = dense_shape[1]

    # The kernel always records 4 column indices per group.
    m_fixed = 4
    # round_up is defined further down in this file; it only has to exist by
    # the time this function is first called.
    A_num_cols_sp_pad = round_up((round_up(ncols, m)/m)*n, 16)
    # Number of entries of `columns` that belong to one tile of rows.
    indexes_cols = A_num_cols_sp_pad//n*m_fixed

    assert dense_dtype in (torch.float32, torch.float64)
    # C++ element type substituted into the generated source.
    dtype = "float" if dense_dtype == torch.float32 else "double"
    lib = compile(
        f"""
        #include <iostream>
        #include <algorithm>
        #include <utility>
        #include <cstdlib>
        #include <cstdio>
        #include <cmath>
        #include <functional>
        #include <tuple>
        #include <vector>
        #include <numeric>
        #include <chrono>

        using namespace std;

        int int_ceil(int x, int y){{
            return (x - 1) / y + 1;
        }}

        extern "C" void func({dtype}* dense, {dtype}* sparse, int* masks, int *columns){{
            int bm_m   = {nrows}/{tileM};
            int mcol_k = {ncols}/{m};
            int mcol_k_p = int_ceil({ncols},{m});
            int m_fixed = 4;

            std::vector<{dtype}> v(m_fixed, 0);
            std::vector<int> vx(m_fixed, 0);
            {dtype} max=0, total=0;

            std::vector<size_t> indices(v.size());
            std::iota(indices.begin(), indices.end(), 0);

            for(int bm_i=0; bm_i<bm_m; bm_i++){{
                int t_bm_i   = bm_i*{tileM}*{ncols};
                for(int mcol_i=0; mcol_i<mcol_k; mcol_i++){{
                    int t_mcol_i = mcol_i*{m};
                    max = 0;

                    std::vector<int> cols_max;
                    cols_max.resize(m_fixed, 0);
                    std::vector<int> masks_max;
                    masks_max.resize({tileM}*{m}, 0);

                    for(int col_i=0; col_i<{m}; col_i++){{
                        vx[0]=col_i;
                        for(int col_j=col_i+1; col_j<{m}; col_j++){{
                            vx[1]=col_j;
                            for(int col_k=col_j+1; col_k<{m}; col_k++){{
                                vx[2]=col_k;
                                for(int col_w=col_k+1; col_w<{m}; col_w++){{
                                    vx[3]=col_w;
                                    total=0;

                                    std::vector<int> mask({tileM}*{m}, 0);

                                    for(int row_i=0; row_i<{tileM}; row_i++){{
                                        int t_row_i  = row_i*{ncols};
                                        v[0]=dense[t_bm_i + t_row_i + t_mcol_i + col_i];
                                        v[1]=dense[t_bm_i + t_row_i + t_mcol_i + col_j];
                                        v[2]=dense[t_bm_i + t_row_i + t_mcol_i + col_k];
                                        v[3]=dense[t_bm_i + t_row_i + t_mcol_i + col_w];

                                        std::partial_sort(indices.begin(), indices.begin() + {n}, indices.end(), [&](size_t A, size_t B) {{
                                                    return v[A] > v[B];}});

                                        for(int id=0; id<{n}; id++){{
                                            total += dense[t_bm_i + t_row_i + t_mcol_i + vx[indices[id]]];

                                            mask[row_i*{m} + vx[indices[id]]] = 1;
                                        }}
                                    }}

                                    if(total>max){{
                                        max = total;
                                        std::copy(mask.begin(), mask.end(), masks_max.begin());

                                        std::sort(vx.begin(), vx.begin() + m_fixed);
                                        std::copy(vx.begin(), vx.end(), cols_max.begin());
                                    }}
                                }}
                            }}
                        }}
                    }}

                    for(int i=0; i<{tileM}; i++){{
                        for(int j=0; j<{m}; j++){{
                            int drop = masks_max[i*{m} + j];
                            masks[t_bm_i  + i*{ncols}+ t_mcol_i + j]  = drop;
                            sparse[t_bm_i + i*{ncols}+ t_mcol_i + j] *= drop;
                        }}
                    }}
                    for(int i=0; i<m_fixed; i++){{
                        columns[bm_i*{indexes_cols} + mcol_i*m_fixed + i] =
                        cols_max[i];
                    }}
                }}
            }}

            int remainder = {ncols}%{m};

            if (remainder>0){{
                int d_rows={tileM}, d_cols=remainder;

                if(remainder<3){{
                    for(int i=0; i<{nrows}; i++){{
                        for(int j={ncols}-remainder; j<{ncols}; j++){{
                            masks[i*{ncols}+j] = 1;
                        }}
                    }}
                    for(int bm_i=0; bm_i<bm_m; bm_i++){{
                        for(int j=0; j<m_fixed; j++){{
                            columns[bm_i*{indexes_cols} + mcol_k*m_fixed + j] = j;
                        }}
                    }}
                }} else if(remainder==3){{
                    v[3] = -1;
                    for(int bm_i=0; bm_i<bm_m; bm_i++){{
                        int t_bm_i   = bm_i*{tileM}*{ncols};
                        for(int mcol_i=mcol_k; mcol_i<mcol_k_p; mcol_i++){{
                            max = 0;
                            int t_mcol_i = mcol_i*{m};

                            std::vector<int> cols_max(m_fixed, 0);
                            std::vector<int> masks_max({tileM}*remainder, 0);

                            for(int col_i=0; col_i<remainder; col_i++){{
                                vx[0]=col_i;
                                for(int col_j=col_i+1; col_j<remainder; col_j++){{
                                    vx[1]=col_j;
                                    for(int col_k=col_j+1; col_k<remainder; col_k++){{
                                        vx[2]=col_k;
                                        total=0;
                                        std::vector<int> mask({tileM}*remainder, 0);

                                        for(int row_i=0; row_i<{tileM}; row_i++){{
                                            int t_row_i  = row_i*{ncols};
                                            v[0]=dense[t_bm_i + t_row_i + t_mcol_i + col_i];
                                            v[1]=dense[t_bm_i + t_row_i + t_mcol_i + col_j];
                                            v[2]=dense[t_bm_i + t_row_i + t_mcol_i + col_k];

                                            std::partial_sort(indices.begin(), indices.begin() + {n}, indices.end(), [&](size_t A, size_t B) {{
                                                        return v[A] > v[B]; }});

                                            for(int id=0; id<{n}; id++){{
                                                total += dense[t_bm_i + t_row_i + t_mcol_i + vx[indices[id]]];

                                                mask[row_i*remainder + vx[indices[id]]] = 1;
                                            }}
                                        }}

                                        if(total>max){{
                                            max = total;
                                            std::copy(mask.begin(), mask.end(), masks_max.begin());

                                            std::sort(vx.begin(), vx.begin() + remainder);//m_fixed
                                            std::copy(vx.begin(), vx.end(), cols_max.begin());
                                        }}
                                    }}
                                }}
                            }}

                            for(int i=0; i<{tileM}; i++){{
                                for(int j=0; j<remainder; j++){{
                                    int drop = masks_max[i*remainder + j];

                                    masks[t_bm_i  + i*{ncols}+ t_mcol_i + j]  = drop;
                                    sparse[t_bm_i + i*{ncols}+ t_mcol_i + j] *= drop;
                                }}
                            }}
                            for(int i=0; i<remainder; i++){{
                                columns[bm_i*{indexes_cols} + mcol_i*m_fixed + i] =
                                cols_max[i];
                            }}
                        }}
                    }}
                }} else {{
                    for(int bm_i=0; bm_i<bm_m; bm_i++){{
                        int t_bm_i   = bm_i*{tileM}*{ncols};
                        for(int mcol_i=mcol_k; mcol_i<mcol_k_p; mcol_i++){{
                            max = 0;
                            int t_mcol_i = mcol_i*{m};

                            std::vector<int> cols_max(m_fixed, 0);
                            std::vector<int> masks_max({tileM}*remainder, 0);

                            for(int col_i=0; col_i<remainder; col_i++){{
                                vx[0]=col_i;
                                for(int col_j=col_i+1; col_j<remainder; col_j++){{
                                    vx[1]=col_j;
                                    for(int col_k=col_j+1; col_k<remainder; col_k++){{
                                        vx[2]=col_k;
                                        for(int col_w=col_k+1; col_w<remainder; col_w++){{
                                            vx[3]=col_w;
                                            total=0;
                                            std::vector<int> mask({tileM}*remainder, 0);

                                            for(int row_i=0; row_i<{tileM}; row_i++){{
                                                int t_row_i  = row_i*{ncols};
                                                v[0]=dense[t_bm_i + t_row_i + t_mcol_i + col_i];
                                                v[1]=dense[t_bm_i + t_row_i + t_mcol_i + col_j];
                                                v[2]=dense[t_bm_i + t_row_i + t_mcol_i + col_k];
                                                v[3]=dense[t_bm_i + t_row_i + t_mcol_i + col_w];

                                                std::partial_sort(indices.begin(), indices.begin() + {n}, indices.end(), [&](size_t A, size_t B) {{
                                                            return v[A] > v[B]; }});

                                                for(int id=0; id<{n}; id++){{
                                                    total += dense[t_bm_i + t_row_i + t_mcol_i + vx[indices[id]]];

                                                    mask[row_i*remainder + vx[indices[id]]] = 1;
                                                }}
                                            }}

                                            if(total>max){{
                                                max = total;
                                                std::copy(mask.begin(), mask.end(), masks_max.begin());

                                                std::sort(vx.begin(), vx.begin() + m_fixed);
                                                std::copy(vx.begin(), vx.end(), cols_max.begin());
                                            }}
                                        }}
                                    }}
                                }}
                            }}

                            for(int i=0; i<{tileM}; i++){{
                                for(int j=0; j<remainder; j++){{
                                    int drop = masks_max[i*remainder + j];

                                    masks[t_bm_i  + i*{ncols}+ t_mcol_i + j]  = drop;
                                    sparse[t_bm_i + i*{ncols}+ t_mcol_i + j] *= drop;
                                }}
                            }}
                            for(int i=0; i<m_fixed; i++){{
                                columns[bm_i*{indexes_cols} + mcol_i*m_fixed + i] =
                                cols_max[i];
                            }}
                        }}
                    }}
                }}
            }}
        }}
        """,
    )
    # Arguments are passed as raw addresses (e.g. tensor.data_ptr()).
    lib.func.argtypes = [
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
    ]
    return lib.func

In [None]:
@cache
def to_dense(dense_shape, dense_dtype, n, m, tileM):
    """Compile and return the native kernel that expands the packed format to dense.

    The generated C++ function has the signature
    ``func3(hA_dense, hA_values, hA_columns, hA_metadata)``. It walks the
    packed layout block by block, reads the 4 column offsets of each column
    group from ``hA_columns``, decodes 2-bit position fields from the words
    of ``hA_metadata``, and copies each stored value from ``hA_values`` to the
    corresponding position of ``hA_dense``. Positions whose column would fall
    at or beyond ``ncols`` are skipped. Entries of ``hA_dense`` that are not
    written keep whatever the caller put there.

    All sizes are baked into the generated source, and the result is memoized
    by ``cache`` per argument tuple.

    Args:
        dense_shape: ``(nrows, ncols)`` of the dense matrix.
        dense_dtype: ``torch.float32`` or ``torch.float64``.
        n: number of stored values per row inside each column group.
        m: width of a column group in the dense matrix.
        tileM: number of rows per tile (``bm`` in the generated source).

    Returns:
        The ``func3`` symbol of the compiled library, with ``argtypes`` set to
        four ``ctypes.c_void_p`` pointers.

    Raises:
        AssertionError: if ``dense_dtype`` is not float32 or float64.
    """
    nrows = dense_shape[0]
    ncols = dense_shape[1]

    # NOTE: A_size and density are not referenced by the generated source below.
    A_size = nrows*ncols
    density = n/m

    brow = 4 #this->brow = brow_;
    mbrow = 32 #this->mbrow = mbrow_;

    bm   = tileM
    # !IMPORTANT! constants because of architecture constraints
    m_fixed = 4
    bits_elem_meta=2
    mrow_m = 2
    bits_elem_cols=8
    brow_fixed = 16
    nelems=32//bits_elem_meta #(sizeof(uint)*8)=32
    nelems_col = nelems//mrow_m

    # Only A_num_cols_sp_pad is substituted into the generated source; the
    # other three values computed here are not used by it.
    A_num_cols_sp = (ncols/m)*n
    A_num_cols_sp_pad_nm = (round_up(ncols, m)/m)*n
    A_num_cols_sp_pad = round_up((round_up(ncols, m)/m)*n, 16)
    A_nnz = nrows*A_num_cols_sp_pad

    assert dense_dtype in (torch.float32, torch.float64)
    # C++ element type substituted into the generated source.
    dtype = "float" if dense_dtype == torch.float32 else "double"
    lib = compile(
        f"""
        #include <iostream>
        #include <algorithm>
        #include <utility>
        #include <cstdlib>
        #include <cstdio>
        #include <cmath>
        #include <functional>
        #include <tuple>
        #include <vector>
        #include <numeric>
        #include <chrono>

        using namespace std;


        extern "C" void func3({dtype}* hA_dense, {dtype}* hA_values, int *hA_columns, int *hA_metadata){{
            //this->hA_dense.resize(this->A_size, 0);

            // general variables N:M format
            int bm_m = {nrows}/{bm};
            int mbrow_m = {bm}/{mbrow};
            int mbrow_m2 = {mbrow}/{brow_fixed};
            int brow_m = {brow_fixed}/{brow};
            // metadata
            int mcol_kk = {nelems}/{mrow_m}/{n};
            int mcol_k = {A_num_cols_sp_pad}/{n}/mcol_kk;
            // indices
            int col_kk = mcol_kk;
            int col_k = {A_num_cols_sp_pad}/{n}/col_kk;

            uint indexes[{nelems}];
            uint columns[col_kk*{m_fixed}];

            for(int bm_i=0; bm_i<bm_m; bm_i++){{
                for(int mbrow_i=0; mbrow_i<mbrow_m; mbrow_i++){{
                    for(int mbrow_i2=0; mbrow_i2<mbrow_m2; mbrow_i2++){{
                        for(int brow_i=0; brow_i<brow_m; brow_i++){{
                            for(int mcol_i=0; mcol_i<mcol_k; mcol_i++){{
                                //read columns indexes
                                for(int col_i=0; col_i<col_kk; col_i++){{
                                    for(int col_ii=0; col_ii<{m_fixed}; col_ii++){{
                                        columns[col_i*{m_fixed} + col_ii] =
                                        hA_columns[bm_i*col_k*col_kk*{m_fixed} + mcol_i*col_kk*{m_fixed} + col_i*{m_fixed} + col_ii];
                                    }}
                                }}
                                // read metadata
                                for(int mbrow_ii=0; mbrow_ii<({brow}/{mrow_m}); mbrow_ii++){{
                                    for(int mbrow_iii=0; mbrow_iii<{mrow_m}; mbrow_iii++){{
                                        for(int mcol_ii=0; mcol_ii<mcol_kk; mcol_ii++){{
                                            for (int n_i=0; n_i<{n}; n_i++) {{
                                                indexes[
                                                    mbrow_iii*{n} +
                                                    mcol_ii*{mrow_m}*{n} +
                                                    n_i] =
                                                (((hA_metadata[
                                                    bm_i*mcol_k*{bm}/{mrow_m} +
                                                    mbrow_i*mcol_k*{mbrow}/{mrow_m} +
                                                    mbrow_i2*{brow_fixed}/{mrow_m} +
                                                    brow_i*{brow}/{mrow_m}  +
                                                    mcol_i*{mbrow}/{mrow_m} +
                                                    mbrow_ii]) >> (mbrow_iii*({nelems}/{mrow_m})*{bits_elem_meta}+mcol_ii*{n}*{bits_elem_meta}+n_i*{bits_elem_meta})) & 0x3);
                                            }}
                                        }}
                                    }}

                                    for(int mcol_ii=0; mcol_ii<mcol_kk; mcol_ii++){{
                                        for(int mbrow_iii=0; mbrow_iii<{mrow_m}; mbrow_iii++){{
                                            for(int n_i=0; n_i<{n}; n_i++){{
                                                unsigned int index = columns[mcol_ii*{m_fixed} + indexes[mcol_ii*{mrow_m}*{n}+mbrow_iii*{n}+n_i]];

                                                if((mcol_i*{m}*mcol_kk + mcol_ii*{m} + index) < {ncols}){{
                                                    hA_dense[
                                                        bm_i*{bm}*{ncols} +
                                                        mbrow_i*{mbrow}*{ncols} +
                                                        mbrow_i2*{brow_fixed}*{ncols} +
                                                        brow_i*{brow}*{ncols} +
                                                        mcol_i*{m}*mcol_kk +
                                                        mbrow_ii*{mrow_m}*{ncols} +
                                                        mcol_ii*{m} +
                                                        mbrow_iii*{ncols} +
                                                        index] =
                                                    hA_values[
                                                        bm_i*{bm}*{A_num_cols_sp_pad} +
                                                        mbrow_i*{mbrow}*{A_num_cols_sp_pad}+
                                                        mbrow_i2*{brow_fixed}*{A_num_cols_sp_pad}+
                                                        brow_i*{brow}*{nelems}/{mrow_m}+
                                                        mcol_i*{brow_fixed}*{nelems}/{mrow_m} +
                                                        mbrow_ii*{mrow_m}*{n} +
                                                        mcol_ii*{n}*{brow} +
                                                        mbrow_iii*{n} +
                                                        n_i];
                                                }}
                                            }}
                                        }}
                                    }}
                                }}
                            }}
                        }}
                    }}
                }}
            }}
        }}
        """,
    )
    # Arguments are passed as raw addresses (e.g. tensor.data_ptr()).
    lib.func3.argtypes = [
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
    ]
    return lib.func3

In [None]:
@cache
def to_sparse_sr_nm(dense_shape, dense_dtype, n, m, tileM):
    """Compile and return the native kernel that packs a masked matrix.

    The generated C++ function has the signature
    ``func2(sparse, masks, hA_values, hA_columns, hA_metadata)``. For each
    block it reads the 4 column offsets of every column group from
    ``hA_columns``, then for each row looks up those 4 positions in ``masks``.
    For every position with a non-zero mask it records the position number
    (0..3) and the matching element of ``sparse``. Positions that would fall
    at or beyond ``ncols`` contribute a zero value and zero index for the first
    two slots. The collected values are written to ``hA_values`` and the
    position numbers are packed as 2-bit fields into one word of
    ``hA_metadata`` per pair of rows.

    Side effect: the kernel prints ``max_idx: <value>`` (the largest index
    written into ``hA_values``) to standard output on every call.

    All sizes are baked into the generated source, and the result is memoized
    by ``cache`` per argument tuple.

    Args:
        dense_shape: ``(nrows, ncols)`` of the dense matrix.
        dense_dtype: ``torch.float32`` or ``torch.float64``.
        n: number of stored values per row inside each column group.
        m: width of a column group in the dense matrix.
        tileM: number of rows per tile (``bm`` in the generated source).

    Returns:
        The ``func2`` symbol of the compiled library, with ``argtypes`` set to
        five ``ctypes.c_void_p`` pointers.

    Raises:
        AssertionError: if ``dense_dtype`` is not float32 or float64.
    """
    nrows = dense_shape[0]
    ncols = dense_shape[1]

    brow = 4 #this->brow = brow_;
    mbrow = 32 #this->mbrow = mbrow_;

    bm   = tileM
    # !IMPORTANT! constants because of architecture constraints
    m_fixed = 4
    bits_elem_meta=2
    mrow_m = 2
    bits_elem_cols=8
    brow_fixed = 16
    nelems=32//bits_elem_meta #(sizeof(uint)*8)=32
    nelems_col = nelems//mrow_m

    # Only A_num_cols_sp_pad is substituted into the generated source; the
    # other three values computed here are not used by it.
    A_num_cols_sp = (ncols//m)*n
    A_num_cols_sp_pad_nm = (round_up(ncols, m)/m)*n
    A_num_cols_sp_pad = round_up((round_up(ncols, m)/m)*n, 16)
    A_nnz = nrows*A_num_cols_sp_pad

    assert dense_dtype in (torch.float32, torch.float64)
    # C++ element type substituted into the generated source.
    dtype = "float" if dense_dtype == torch.float32 else "double"
    lib = compile(
        f"""
        #include <iostream>
        #include <algorithm>
        #include <utility>
        #include <cstdlib>
        #include <cstdio>
        #include <cmath>
        #include <functional>
        #include <tuple>
        #include <vector>
        #include <numeric>
        #include <chrono>

        using namespace std;


        extern "C" void func2({dtype}* sparse, int* masks, {dtype}* hA_values, int *hA_columns, int *hA_metadata){{

            int bm_m = {nrows}/{bm};
            int mbrow_m = {bm}/{mbrow};
            int mbrow_m2 = {mbrow}/{brow_fixed};
            int brow_m = {brow_fixed}/{brow};
            // metadata
            int mcol_kk = {nelems}/{mrow_m}/{n};
            int mcol_k = {A_num_cols_sp_pad}/{n}/mcol_kk;
            // indices
            int col_kk = mcol_kk;
            int col_k = {A_num_cols_sp_pad}/{n}/col_kk;

            {dtype} values[{nelems}];
            uint indexes[{nelems}];
            uint columns[col_kk*{m_fixed}];

            int max_idx = 0;

            for(int bm_i=0; bm_i<bm_m; bm_i++){{
                for(int mbrow_i=0; mbrow_i<mbrow_m; mbrow_i++){{
                    for(int mbrow_i2=0; mbrow_i2<mbrow_m2; mbrow_i2++){{
                        for(int brow_i=0; brow_i<brow_m; brow_i++){{
                            for(int mcol_i=0; mcol_i<mcol_k; mcol_i++){{
                                for(int col_i=0; col_i<col_kk; col_i++){{
                                    for(int col_ii=0; col_ii<{m_fixed}; col_ii++){{
                                        columns[col_i*{m_fixed} + col_ii] =
                                        hA_columns[bm_i*col_k*col_kk*{m_fixed} + mcol_i*col_kk*{m_fixed} + col_i*{m_fixed} + col_ii];
                                    }}
                                }}
                                for(int mbrow_ii=0; mbrow_ii<({brow}/{mrow_m}); mbrow_ii++){{
                                    for(int mcol_ii=0; mcol_ii<mcol_kk; mcol_ii++){{
                                        for(int mbrow_iii=0; mbrow_iii<{mrow_m}; mbrow_iii++){{
                                            int pos=0;
                                            for(int n_i=0; n_i<{m_fixed}; n_i++){{
                                                unsigned int index = columns[mcol_ii*{m_fixed} + n_i];

                                                if((mcol_i*{m}*mcol_kk + mcol_ii*{m} + index) < {ncols}){{
                                                    int nnz = masks[
                                                            bm_i*{bm}*{ncols} +
                                                            mbrow_i*{mbrow}*{ncols} +
                                                            mbrow_i2*{brow_fixed}*{ncols} +
                                                            brow_i*{brow}*{ncols} +
                                                            mcol_i*{m}*mcol_kk +
                                                            mbrow_ii*{mrow_m}*{ncols} +
                                                            mcol_ii*{m} +
                                                            mbrow_iii*{ncols} +
                                                            index];

                                                    if(nnz != 0){{
                                                        indexes[
                                                            mbrow_iii*{n} +
                                                            mcol_ii*{mrow_m}*{n} +
                                                            pos] = n_i;

                                                        values[
                                                            mcol_ii*{mrow_m}*{n} +
                                                            mbrow_iii*{n} +
                                                            pos] =
                                                        sparse[
                                                            bm_i*{bm}*{ncols} +
                                                            mbrow_i*{mbrow}*{ncols} +
                                                            mbrow_i2*{brow_fixed}*{ncols} +
                                                            brow_i*{brow}*{ncols} +
                                                            mcol_i*{m}*mcol_kk +
                                                            mbrow_ii*{mrow_m}*{ncols} +
                                                            mcol_ii*{m} +
                                                            mbrow_iii*{ncols} +
                                                            index];

                                                        pos+=1;
                                                    }}
                                                }} else {{
                                                    if(n_i<2){{
                                                        indexes[
                                                            mbrow_iii*{n} +
                                                            mcol_ii*{mrow_m}*{n} +
                                                            pos] = 0;

                                                        values[
                                                            mcol_ii*{mrow_m}*{n} +
                                                            mbrow_iii*{n} +
                                                            pos] = 0;

                                                        pos+=1;
                                                    }}
                                                }}
                                            }}
                                        }}
                                    }}
                                    // write metadata
                                    unsigned int meta=0;
                                    for(int mbrow_iii=0; mbrow_iii<{mrow_m}; mbrow_iii++){{
                                        for(int mcol_ii=0; mcol_ii<mcol_kk; mcol_ii++){{
                                            for (int n_i=0; n_i<{n}; n_i++) {{

                                                int idx = bm_i*{bm}*{A_num_cols_sp_pad} +
                                                        mbrow_i*{mbrow}*{A_num_cols_sp_pad}+
                                                        mbrow_i2*{brow_fixed}*{A_num_cols_sp_pad}+
                                                        brow_i*{brow}*{nelems}/{mrow_m}+
                                                        mcol_i*{brow_fixed}*{nelems}/{mrow_m} +
                                                        mbrow_ii*{mrow_m}*{n} +
                                                        mcol_ii*{n}*{brow} +
                                                        mbrow_iii*{n} +
                                                        n_i;

                                                max_idx = (idx>max_idx)?(idx):(max_idx);

                                                hA_values[
                                                        idx] =
                                                values[
                                                    mcol_ii*{mrow_m}*{n} +
                                                    mbrow_iii*{n} +
                                                    n_i];

                                                unsigned int tmp = indexes[
                                                            mbrow_iii*{n} +
                                                            mcol_ii*{mrow_m}*{n} +
                                                            n_i];
                                                meta |= (tmp << (mbrow_iii*({nelems}/{mrow_m})*{bits_elem_meta}+mcol_ii*{n}*{bits_elem_meta}+n_i*{bits_elem_meta}));
                                            }}
                                        }}
                                    }}
                                    hA_metadata[bm_i*mcol_k*{bm}/{mrow_m} +
                                                mbrow_i*mcol_k*{mbrow}/{mrow_m} +
                                                mbrow_i2*{brow_fixed}/{mrow_m} +
                                                brow_i*{brow}/{mrow_m}  +
                                                mcol_i*{mbrow}/{mrow_m} +
                                                mbrow_ii] = meta;
                                }}
                            }}
                        }}
                    }}
                }}
            }}
            cout << "max_idx: " << max_idx << endl;
        }}
        """,
    )
    # Arguments are passed as raw addresses (e.g. tensor.data_ptr()).
    lib.func2.argtypes = [
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
    ]
    return lib.func2

In [None]:
def round_up(x, y):
    """Return the smallest multiple of `y` that is greater than or equal to `x`.

    The quotient is computed with true division and rounded up with
    `math.ceil`, so the multiplier is always an int even when `x` is a float.
    """
    multiples_needed = math.ceil(x / y)
    return multiples_needed * y

In [None]:
# Dense reference run: d = (a + b) @ c, followed by a backward pass with an
# explicit upstream gradient so that a, b and c all receive .grad values.
a = torch.randn(768, 768).requires_grad_()
b = torch.randn(768, 768).requires_grad_()
c = torch.ones(768, 768).requires_grad_()
grad_d = torch.randn(768, 768)

d = torch.add(a, b).mm(c)
d.backward(gradient=grad_d)

In [None]:
class SrNMTensor:
    """Container for a matrix stored in the packed n:m format used in this example.

    The constructor packs the given dense matrix immediately. Afterwards the
    instance holds `values`, `metadata` and `columns` tensors plus the
    original `nrows`/`ncols`, and `to_dense()` expands them again.
    """

    def __init__(self, n_, m_, tileM_, dense_, mask_, columns_):
        self.n, self.m, self.tileM = n_, m_, tileM_
        self.columns = columns_
        # Placeholders; nnz, nrows, ncols and metadata are filled in by the
        # packing step below. `data` is never assigned anywhere in this class.
        self.nnz = 0
        self.nrows = None
        self.ncols = None
        self.data = None
        self.metadata = None
        self.to_sparse_sr_nm(dense_, mask_)

    def to_sparse_sr_nm(self, dense_, mask_):
        """Pack `dense_` into `self.values` / `self.metadata` using `mask_`.

        Looks up the compiled kernel for this shape and dtype, allocates the
        zero-initialized output tensors and lets the kernel fill them through
        raw data pointers.
        """
        # Inside a method this name resolves to the module-level builder.
        kernel = to_sparse_sr_nm(dense_.shape, dense_.dtype, self.n, self.m, self.tileM)

        self.nrows, self.ncols = dense_.shape
        padded_cols = round_up((round_up(self.ncols, self.m)/self.m)*self.n, 16)
        self.nnz = self.nrows*padded_cols

        # Same layout constants as the kernel builder uses.
        mrow_m = 2
        bits_elem_meta = 2
        nelems = 32//bits_elem_meta  # 32 = sizeof(uint)*8
        nelems_col = nelems//mrow_m

        self.values = torch.zeros(self.nnz, dtype=dense_.dtype)
        metadata_len = (self.nrows//mrow_m * padded_cols)//nelems_col
        self.metadata = torch.zeros(metadata_len, dtype=torch.int32)

        kernel(
            dense_.data_ptr(),
            mask_.data_ptr(),
            self.values.data_ptr(),
            self.columns.data_ptr(),
            self.metadata.data_ptr(),
        )

    def to_dense(self):
        """Expand the packed representation into a new dense tensor and return it."""
        shape = (self.nrows, self.ncols)
        # Inside a method this name resolves to the module-level builder.
        kernel = to_dense(shape, self.values.dtype, self.n, self.m, self.tileM)

        result = torch.zeros(shape, dtype=self.values.dtype)
        kernel(
            result.data_ptr(),
            self.values.data_ptr(),
            self.columns.data_ptr(),
            self.metadata.data_ptr(),
        )
        return result

In [None]:
class NMVectorSparsifier:
    """Parameter holder for the n:m vector sparsifier.

    Stores `n`, `m` and `tileM` unchanged as attributes of the same names.
    """

    def __init__(self, n, m, tileM):
        self.n, self.m, self.tileM = n, m, tileM
    

In [None]:
def nm_vector_mask_sparsify(tensor, n, m, tileM):
    """Compute the mask and column-index buffers for a 2-D tensor.

    The native ``group_n_m2`` kernel is run on the element-wise absolute
    values of ``tensor`` (a detached CPU copy), so the input is not modified.

    Args:
        tensor: 2-D tensor; ``group_n_m2`` only accepts float32 or float64.
        n, m, tileM: Format parameters forwarded to the kernel builder.

    Returns:
        ``(masks, columns)``: an int32 tensor shaped like ``tensor`` and a
        flat int32 tensor of column indices, both filled by the kernel.

    Raises:
        NotImplementedError: If ``tensor`` is not 2-D.
    """
    print("nm_vector_mask_sparsify", n, m, tileM)

    # Check the rank before unpacking the shape; otherwise a non-2-D input
    # fails with an unrelated ValueError and this branch is never reached.
    if len(tensor.shape) != 2:
        raise NotImplementedError(
            f"Only 2-D tensors are supported, got {len(tensor.shape)} dimension(s)"
        )

    nrows, ncols = tensor.shape
    # Compressed column count, padded up to a multiple of 16.
    padded_cols = round_up((round_up(ncols, m) / m) * n, 16)
    m_fixed = 4

    # Output buffers written by the native kernel.
    masks = torch.zeros(tensor.shape, dtype=torch.int32)
    columns = torch.zeros(
        nrows // tileM * padded_cols // n * m_fixed, dtype=torch.int32
    )

    # The kernel gets raw pointers and knows only the shape, so it needs
    # row-major CPU memory; .abs() alone would keep a non-contiguous layout.
    magnitudes = tensor.cpu().detach().abs().contiguous()
    sparse = magnitudes.clone()

    kernel = group_n_m2(magnitudes.shape, magnitudes.dtype, n, m, tileM)
    kernel(
        magnitudes.data_ptr(),
        sparse.data_ptr(),
        masks.data_ptr(),
        columns.data_ptr(),
    )

    return masks, columns

In [None]:
@sten.register_sparsifier_implementation(
    sparsifier=NMVectorSparsifier, inp=torch.Tensor, out=SrNMTensor
)
def torch_tensor_to_srnm_random_fraction(sparsifier, tensor, grad_fmt=None):
    """Convert a dense tensor into a wrapped ``SrNMTensor``.

    Computes the mask and column indices with ``nm_vector_mask_sparsify``,
    builds the ``SrNMTensor`` from them and wraps it with
    ``sten.SparseTensorWrapper.wrapped_from_dense``.

    NOTE(review): the name mentions "random_fraction", but nothing in this
    body is random — confirm whether the name is a leftover.

    Args:
        sparsifier: ``NMVectorSparsifier`` supplying ``n``, ``m``, ``tileM``.
        tensor: Dense tensor to sparsify.
        grad_fmt: Passed through unchanged to ``wrapped_from_dense``.
    """
    print("inside NMVectorSparsifier sparsifier")
    n, m, tile_m = sparsifier.n, sparsifier.m, sparsifier.tileM
    masks, columns = nm_vector_mask_sparsify(tensor, n, m, tile_m)
    compressed = SrNMTensor(n, m, tile_m, tensor, masks, columns)
    return sten.SparseTensorWrapper.wrapped_from_dense(
        compressed, tensor, grad_fmt
    )

In [None]:
# Format parameters shared by the sparsifier instances below.
n = 2
m = 4
tileM = 128

# NOTE(review): these constants are not referenced in this part of the
# notebook; presumably kernel tiling parameters — confirm before removing.
BM = 32
BN = 32
BK = 32
WM = 32
WN = 32
WK = 32
MM = 16
MN = 8
MK = 32
NSTAGE = 2

# torch.add variant whose output and output gradient are both requested in
# SrNMTensor format via NMVectorSparsifier.
sparse_add = sten.sparsified_op(
    orig_op=torch.add,
    out_fmt=(
        (
            sten.KeepAll(),
            torch.Tensor,
            NMVectorSparsifier(n, m, tileM),
            SrNMTensor,
        ),
    ),
    grad_out_fmt=(
        (
            sten.KeepAll(),
            torch.Tensor,
            NMVectorSparsifier(n, m, tileM),
            SrNMTensor,
        ),
    ),
)

In [None]:
# Dense reference: the plain torch.add result for a and b.
print(torch.add(a, b))

Then we apply the sparsified operator `sparse_add` to the same inputs.

In [None]:
# Same addition through the sparsified operator defined above.
srnm = sparse_add(a, b)

In [None]:
# Expand the wrapped sparse result back to a dense tensor for inspection.
srnm.wrapped_tensor.to_dense()

In [None]:
# Check whether the densified sparse result equals the dense sum exactly.
print( torch.equal(srnm.wrapped_tensor.to_dense(), torch.add(a, b)) )

In [None]:
# Moves the packed values to the GPU before printing; needs a CUDA device.
print( srnm.wrapped_tensor.values.cuda() )

In [None]:
import spatha

In [None]:
@sten.register_fwd_op_impl(
    operator=torch.mm,
    inp=(SrNMTensor, torch.Tensor),
    out=[(sten.KeepAll, torch.Tensor)],
)
def sparse_torch_add_fwd_impl(ctx, inputs, output_sparsifiers):
    """Forward ``torch.mm`` for an ``SrNMTensor`` left operand via ``spatha.spmm``.

    Both operands are saved on ``ctx``. The packed values, the dense
    right-hand side and the bias are cast to half precision and moved to the
    GPU before the kernel call, so a CUDA device is required.

    NOTE(review): despite the name, this is registered for ``torch.mm``, not
    ``torch.add``.
    NOTE(review): a constant bias of 2 per row is always passed to the
    kernel; presumably the result is the product plus 2 — confirm this is
    wanted outside the demo.

    Args:
        ctx: Context object; only ``save_for_backward`` is used.
        inputs: Pair of (wrapped ``SrNMTensor``, dense ``torch.Tensor``).
        output_sparsifiers: Not used by this implementation.

    Returns:
        Whatever ``spatha.spmm`` returns for the given operands.
    """
    lhs, rhs = inputs
    ctx.save_for_backward(lhs, rhs)

    packed = lhs.wrapped_tensor
    bias = torch.ones(packed.nrows) * 2

    def as_cuda_half(t):
        return t.to(dtype=torch.half).cuda()

    return spatha.spmm(
        packed.metadata.cuda(),       # metadata
        packed.columns.cuda(),        # indices
        as_cuda_half(packed.values),  # values
        as_cuda_half(rhs),            # rhs_matrix
        as_cuda_half(bias),           # bias
        packed.nrows,                 # A_num_rows
        packed.ncols,                 # A_num_cols
        rhs.shape[1],                 # B_num_cols
        packed.tileM,                 # vec_length
        packed.n,                     # n
        packed.m,                     # m
        packed.nnz,                   # nnz
        0,                            # seed
        32,                           # mbrow
        4,                            # brow
    )

#a1 = sparse_add(a, b).to(dtype=torch.half).to(device="cuda:0")
#c = c.to(dtype=torch.half).to(device="cuda:0")

# torch.mm with an SrNMTensor left operand, matching the input formats of the
# forward implementation registered above.
d = torch.mm(sparse_add(a, b), c)

#print(d)

In [None]:
# Reference product: densify the sparse operand, cast it to half, multiply by
# c with NumPy on the CPU, then move the result to cuda:0 as half.
e = torch.from_numpy(srnm.wrapped_tensor.to_dense().to(dtype=torch.half).detach().numpy() @ c.detach().numpy()).to(device="cuda:0").to(dtype=torch.half)
#e = torch.from_numpy(a.detach().numpy() @ c.detach().numpy()).to(device="cuda:0")
print(e)

In [None]:
# Add the same constant 2 that the registered torch.mm implementation passes
# to spatha.spmm as bias, so the reference is comparable with d.
bias = torch.ones((e.shape))*2
e+=bias.to(dtype=torch.half).cuda()

In [None]:
# Exact comparison of the kernel output with the reference.
# NOTE(review): d is compared transposed; presumably spmm returns the product
# in transposed layout — confirm.
print( torch.equal(d.T,e) )

In [None]:
# Tolerance-based comparison (default torch.allclose tolerances).
print( torch.allclose(d.T,e) )

In [None]:
# Indices of rows where the kernel output and the reference differ. This was
# previously computed and discarded, because only a cell's last expression is
# displayed; print it explicitly.
print(torch.any(d.T != e, axis=1).nonzero().flatten())
# Show row 1 of both results side by side.
print(d.T[1])
print(e[1])

In [None]:
# NOTE(review): if d is a plain torch.Tensor, `values` is a method, so this
# prints the bound method rather than any data — confirm what was intended.
print(d.values)