In [1]:
import statistics
import subprocess
import ctypes

import sten

import math
import numpy as np
import torch
from torch.profiler import profile, record_function, ProfilerActivity

from pathlib import Path

import timeit
import sys
import time

from grouped_nmv_tensor import SrNMTensor, nm_vector_mask_sparsify

import spatha

  return self.fget.__get__(instance, owner)()
  value = getter(object, key)


In [2]:
# V:N:M sparsity configuration shared by the cells below.
v = 64  # V: number of rows in each pruning block
n = 2   # N: number of columns kept (non-zero) per row inside a block
m = 16  # M: number of columns in each pruning block

In [3]:
# The notebook only runs forward computations, so autograd is switched off globally.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fcddd2f0890>

# 1. Dense computation with torch.matmul()

>Assuming the original matrix W has a shape of [2048 * 4096].

>Assuming the original matrix X has a shape of [32768 * 2048].

>Then, the shape of the matrix multiplication Y = XW is [32768 * 4096].

![pic1](For_the_example_notebook.png "example")

After transposing the original matrix, a pruned_matrix is obtained by structured pruning according to the v:n:m scheme.

The logic is as follows: traverse each V×M block and first select 4 columns in which non-zero elements may exist; then, for each row of the block, randomly choose N of these 4 columns (N = 2 in this notebook) and keep the original values there. All other entries are set to zero.

In [4]:
def vnm_random_pruning_torch(matrix, V, N, M):
    """
    Perform v:n:m random structured pruning on a given matrix.

    The matrix is tiled into VxM blocks. Within each block, 4 columns are
    randomly chosen as candidate non-zero columns. Then, for each row of the
    block, N of these 4 candidates are randomly selected and the original
    values at those positions are kept; every other entry is zero.

    Random choices are drawn with ``torch.randperm`` (one draw per block, then
    one draw per row), so the result is controlled by ``torch.manual_seed``.

    Parameters:
    - matrix (torch.Tensor): The original 2-D matrix to be pruned. It is not modified.
    - V (int): Number of rows in each block. Must divide the number of rows.
    - N (int): Number of columns to keep per row inside a block (0 <= N <= 4).
    - M (int): Total number of columns in each block (M >= 4). Must divide the number of columns.

    Returns:
    - torch.Tensor: A new matrix of the same shape, dtype and device as the input,
      with elements pruned according to the v:n:m scheme.

    Raises:
    - ValueError: If the matrix is not 2-D, if its shape is not a multiple of
      the block shape, or if N / M are incompatible with the 4 candidate columns.
    """
    # Number of candidate columns per block from which each row picks N.
    candidates_per_block = 4

    if matrix.dim() != 2:
        raise ValueError(f"matrix must be 2-D, got {matrix.dim()} dimension(s)")
    rows, cols = matrix.shape

    # Validate up front: otherwise the loops below would index out of bounds
    # half-way through (or silently keep the wrong number of columns).
    if V <= 0 or rows % V != 0:
        raise ValueError(f"number of rows ({rows}) must be a multiple of V ({V})")
    if M < candidates_per_block or cols % M != 0:
        raise ValueError(
            f"M ({M}) must be >= {candidates_per_block} and divide the number of columns ({cols})"
        )
    if not 0 <= N <= candidates_per_block:
        raise ValueError(f"N ({N}) must be between 0 and {candidates_per_block}")

    # Record the kept positions in a boolean mask instead of copying values row
    # by row: the copy then happens once at the end, whatever device `matrix` is on.
    keep = torch.zeros(rows, cols, dtype=torch.bool)

    # Traverse each VxM block
    for row_block in range(0, rows, V):
        for col_block in range(0, cols, M):
            # Randomly select the candidate columns of this block (absolute column indices)
            possible_cols = torch.randperm(M)[:candidates_per_block] + col_block

            # For each row, randomly select N of the candidate columns to keep
            for v_row in range(V):
                selected_cols = possible_cols[torch.randperm(candidates_per_block)[:N]]
                keep[row_block + v_row, selected_cols] = True

    keep = keep.to(matrix.device)
    pruned_matrix = torch.zeros_like(matrix)
    pruned_matrix[keep] = matrix[keep]

    return pruned_matrix

In [5]:
# Seed before drawing any random numbers so that the dense matrix itself,
# and not only the pruning pattern, is reproducible on Restart & Run All.
torch.manual_seed(0)
matrix = torch.rand(4096, 2048, device="cuda:0", dtype=torch.float16)

pruned_matrix = vnm_random_pruning_torch(matrix, v, n, m)
# Make sure the pruned matrix lives on the GPU used by the rest of the notebook.
pruned_matrix = pruned_matrix.to("cuda:0")

In [6]:
# Save the pruned matrix to "pruned_matrix.pt" in the current working directory.
torch.save(pruned_matrix, "pruned_matrix.pt")

In [7]:
# Display the shape of the pruned matrix.
pruned_matrix.size()

torch.Size([4096, 2048])

In [8]:
# pruned_matrix was built in transposed layout, so the original matrix W
# (the right-hand operand of Y = XW) is the transpose of pruned_matrix.
w = pruned_matrix.T
w.size()

torch.Size([2048, 4096])

In [9]:
# Dense left-hand operand X of Y = XW.
x = torch.randn(32768, 2048, dtype=torch.float16, device="cuda:0")
x.size()

torch.Size([32768, 2048])

In [10]:
# Dense reference result Y = XW, computed by PyTorch's matmul.
y = x @ w
y.size()

torch.Size([32768, 4096])

# 2. Sparse matrix computation with the spatha library

In [11]:
class NMVectorSparsifier:
    """
    Callable that converts a dense tensor into a sten-wrapped ``SrNMTensor``.

    The sparsifier stores the ``n``, ``m`` and ``tileM`` settings and forwards
    them to ``nm_vector_mask_sparsify`` and ``SrNMTensor`` on every call.
    """

    def __init__(self, n, m, tileM):
        """Remember the n, m and tileM settings used for every conversion."""
        self.n, self.m, self.tileM = n, m, tileM

    def __call__(self, tensor, grad_fmt=None):
        """
        Sparsify ``tensor`` and wrap the result for sten.

        Parameters:
        - tensor: Dense tensor to convert.
        - grad_fmt: Forwarded unchanged to ``SparseTensorWrapper.wrapped_from_dense``.

        Returns:
        - The object returned by ``sten.SparseTensorWrapper.wrapped_from_dense``.
        """
        # Compute the mask and the column information for the requested pattern.
        mask, columns = nm_vector_mask_sparsify(tensor, self.n, self.m, self.tileM)

        # Build the sparse tensor on the same device as the dense input.
        nm_tensor = SrNMTensor(
            self.n, self.m, self.tileM, tensor, mask, columns, tensor.device
        )

        return sten.SparseTensorWrapper.wrapped_from_dense(nm_tensor, tensor, grad_fmt)

In [12]:
def sparse_dense_mul_dispatch(sparse_values, sparse_indices, sparse_metadata, dense, nrows_sp, ncols_sp, ncols_d, m, n, v, nnz):
    """
    Multiply a sparse matrix by a dense matrix through ``spatha.spmm``.

    Parameters:
    - sparse_values: Values of the sparse operand.
    - sparse_indices: Indices of the sparse operand.
    - sparse_metadata: Metadata of the sparse operand.
    - dense (torch.Tensor): Dense right-hand operand; made contiguous before the call.
    - nrows_sp (int): Passed to spatha as the sparse operand's row count (A_num_rows).
    - ncols_sp (int): Passed to spatha as the sparse operand's column count (A_num_cols).
    - ncols_d (int): Passed to spatha as the dense operand's column count (B_num_cols).
    - m (int): Passed to spatha as ``m``.
    - n (int): Passed to spatha as ``n``.
    - v (int): Passed to spatha as the vector length.
    - nnz (int): Passed to spatha as the number of non-zeros.

    Returns:
    - Whatever ``spatha.spmm`` returns for these arguments.
    """
    dense_ = dense.contiguous()

    # This wrapper does not accept a bias yet, so an all-zero bias is passed.
    # It is allocated on the dense operand's device rather than a fixed GPU,
    # so the call also works when the data is not on 'cuda:0'.
    bias = torch.zeros(nrows_sp, dtype=torch.float16, device=dense_.device)

    output = spatha.spmm(sparse_metadata,  # metadata
                         sparse_indices,   # indices
                         sparse_values,    # values
                         dense_,           # rhs_matrix
                         bias,
                         nrows_sp,         # A_num_rows
                         ncols_sp,         # A_num_cols
                         ncols_d,          # B_num_cols
                         v,                # vec_length
                         n,                # n
                         m,                # m
                         nnz,              # nnz
                         0,                # seed
                         32,               # mbrow
                         4                 # brow
                         )

    return output

In [13]:
# Convert pruned_matrix to the sparse format (tileM is set to v) and keep the
# tensor held by the sten wrapper.
w_transpose = NMVectorSparsifier(n, m, v)(pruned_matrix).wrapped_tensor

In [14]:
# Pull the pieces of the sparse tensor that the spatha call needs.
values = torch.nn.Parameter(w_transpose.values)
columns, metadata = w_transpose.columns, w_transpose.metadata
nrows_sp, ncols_sp, nnz = w_transpose.nrows, w_transpose.ncols, w_transpose.nnz
# x is passed transposed to the kernel, so its first dimension is the dense column count.
ncols_d = x.shape[0]

In [15]:
# Run the sparse x dense product; x is passed transposed as the dense operand.
output = sparse_dense_mul_dispatch(values, columns, metadata, x.T, nrows_sp, ncols_sp,
                                           ncols_d, m, n, v, nnz)

In [16]:
# Display the shape of the spatha result.
output.size()

torch.Size([32768, 4096])

# 3. Check that output and y are approximately equal

In [17]:
# Compare the sparse result with the dense reference, using an absolute
# tolerance of 0.1 (rtol is left at its default).
torch.allclose(output, y, atol=0.1)

True