In [3]:
import torch
import diopilib
from conformance.diopi_functions import spmm
from conformance.diopi_runtime import Tensor
import numpy as np
from torch.nn.functional import dropout
# sparse matrix:
# 0.5,  0,  1
# 0,    0,  2
# 1,    3,  0.6

M, K, N = 4096, 4096, 128

a = torch.randn((M,K),dtype=torch.float32)
a = dropout(a, p=0.9)
sparse_a = a.to_sparse_csr()
print(a)
print(sparse_a)
b = np.random.randn(K,N).astype(np.float32)
print(b)

input = Tensor.from_numpy(b)
row_ptr = Tensor.from_numpy(sparse_a.crow_indices().numpy().astype(np.int32))
col_ind = Tensor.from_numpy(sparse_a.col_indices().numpy().astype(np.int32))
values = Tensor.from_numpy(sparse_a.values().numpy().astype(np.float32))
print(list(input.size().data))
c = spmm(row_ptr, col_ind, values, input)
c_ref = a @ torch.from_numpy(b)
# c = rbrmsr_spmm(sparse_a.crow_indices(), sparse_a.col_indices(), sparse_a.values(), b)


tensor([[ 0.0000, -0.0000, -0.0000,  ..., -0.0000,  0.0000,  0.0000],
        [-0.0000,  0.0000, -0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0000,  0.0000, -0.0000,  ...,  0.0000, -0.0000,  0.0000],
        ...,
        [ 0.0000, 23.0290,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-3.7184, -0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0000,  ..., -0.0000, -4.7012,  0.0000]])
tensor(crow_indices=tensor([      0,     426,     850,  ..., 1676665,
                            1677043, 1677436]),
       col_indices=tensor([  25,   40,   72,  ..., 4071, 4075, 4094]),
       values=tensor([11.1403, -1.7448,  5.9189,  ...,  0.0816, 10.1160,
                      -4.7012]), size=(4096, 4096), nnz=1677436,
       layout=torch.sparse_csr)
[[-0.12721 -0.49293 -1.36716 ...  0.11863  0.81915  0.1262 ]
 [-2.30494 -0.50631 -0.23826 ... -0.88652  0.44469  0.43873]
 [ 1.19146  0.83548 -1.03986 ... -0.18479 -1.14953  0.54477]
 ...
 [ 0.43145  1.52025

In [4]:
c_ours = torch.from_numpy(c.numpy())
print("ours:", c_ours)
print("ref: ", c_ref)
assert torch.allclose(c_ours, c_ref, rtol=1e-03, atol=1e-03)

ours: tensor([[ 189.8907,  -14.9602,  114.9101,  ...,   84.5701,   -9.3818,
          367.1490],
        [ -49.5185,  155.6731, -420.4368,  ...,  107.9797,  367.4091,
           -8.2983],
        [ 100.9162,   10.6236,  311.1810,  ...,  -59.7648,  264.6500,
          -59.7523],
        ...,
        [-231.3855,  114.1076, -277.9331,  ..., -188.6128,   57.9538,
         -185.5934],
        [  57.2971,  128.9449,  -58.4205,  ...,  -96.2650,  -76.2732,
          210.6950],
        [-113.7153,   34.8387,   90.0678,  ...,  -45.6139,   31.5837,
         -257.7583]])
ref:  tensor([[ 189.8907,  -14.9602,  114.9101,  ...,   84.5701,   -9.3818,
          367.1489],
        [ -49.5186,  155.6732, -420.4370,  ...,  107.9796,  367.4092,
           -8.2982],
        [ 100.9163,   10.6236,  311.1810,  ...,  -59.7648,  264.6500,
          -59.7523],
        ...,
        [-231.3856,  114.1077, -277.9332,  ..., -188.6128,   57.9537,
         -185.5935],
        [  57.2971,  128.9448,  -58.4205,  ...,  -9