In [2]:
import builtins

from isplib.matmul import *
from isplib.tensor import SparseTensor
from scipy.sparse import coo_matrix
import torch 

index = torch.tensor([[0, 0, 1, 2, 2],
                      [0, 2, 1, 0, 1]])
value = torch.Tensor([1, 2, 4, 1, 3])
matrix = torch.Tensor([[90, 4], [2, 5], [3, 6]])

a = SparseTensor.from_scipy(coo_matrix((value, index), shape=(3, 3)))
b = matrix
builtins.FUSEDMM = True
spmm_sum(a, b)

Using FusedMM SpMM...


tensor([[96., 16.],
        [ 8., 20.],
        [96., 19.]])

In [3]:
import builtins
builtins.FUSEDMM = True


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import sklearn.metrics as metrics

from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
dataset = Planetoid("Planetoid", name="Cora", transform=T.ToSparseTensor())

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import SAGEConv
from isplib.tensor import SparseTensor

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16, cached=True)
        self.conv2 = GCNConv(16, dataset.num_classes, cached=True)

    def forward(self, data):
        x, adj_t = data.x, data.adj_t
        x = self.conv1(x, adj_t)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, adj_t)

        return F.log_softmax(x, dim=1)



In [29]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)

def train_GCN():
  builtins.FUSEDMM = True
  optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

  model.train()
  for epoch in range(10):
      optimizer.zero_grad()
      out = model(data)
      loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
      loss.backward()
      optimizer.step()
      
      _, pred = model(data).max(dim=1)
      correct = float (pred[data.train_mask].eq(data.y[data.train_mask]).sum().item())
      acc = correct / data.train_mask.sum().item()
      print('Epoch: %d, Accuracy: %.4f'%(epoch,acc))

def test_GCN(fusedmm):
  builtins.FUSEDMM = fusedmm  # Use FusedMM or not
  _, pred = model(data).max(dim=1)
  correct = float (pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
  acc = correct / data.test_mask.sum().item()
  # print('Accuracy: {:.4f}'.format(acc))
  return acc

In [30]:
import cProfile, pstats
from pstats import SortKey

# https://gist.github.com/romuald/0346c76cfbbbceb3e4d1

def f8(x):
    ret = "%8.6f" % x
    if ret != '   0.000':
        return ret
    return "%6dµs" % (x * 1000000)

pstats.f8 = f8


In [31]:
train_GCN()

Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Epoch: 0, Accuracy: 0.5000
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Epoch: 1, Accuracy: 0.6857
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Epoch: 2, Accuracy: 0.7071
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Epoch: 3, Accuracy: 0.7786
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Epoch: 4, Accuracy: 0.8643
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Epoch: 5, Accuracy: 0.8429
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Epoch: 6, Accuracy: 0.8214
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Epoch: 7, Accuracy: 0.9429
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM SpMM...
Using FusedMM 

In [32]:
# cProfile.run('test_GCN(False)', sort=SortKey.CUMULATIVE)
# test_GCN(True), test_GCN(False)

print("Accuracy without FusedMM: ", test_GCN(False)* 100, "%")
print("Accuracy with FusedMM: ", test_GCN(True) * 100 , "%")


Accuracy without FusedMM:  66.3 %
Using FusedMM SpMM...
Using FusedMM SpMM...
Accuracy with FusedMM:  65.8 %


In [None]:
# cProfile.run('test_GCN(True)', sort=SortKey.CUMULATIVE)
# 2 0.000649 0.000324 0.000649 0.000324 {built-in method torch._ops.isplib.spmm_sum}
# 2 0.000325 0.000162 0.000325 0.000162 {built-in method torch._ops.isplib.fusedmm_spmm}

In [36]:
import io

def get_cumulative_time(FusedMM=False):
    with cProfile.Profile() as pr:
        test_GCN(FusedMM)
        txt = io.StringIO()
        p = pstats.Stats(pr, stream=txt)
        p.print_stats('isplib.spmm_sum' if not FusedMM else 'isplib.fusedmm_spmm')
        # print(txt.getvalue())
        return txt.getvalue().strip().split('\n')[-1].split(' ')[-4]

In [43]:
torch_op_time = float(get_cumulative_time(False))
fusedmm_time = float(get_cumulative_time(True))
speedup = torch_op_time / fusedmm_time
print("Non-FusedMM time: ", torch_op_time, 'seconds')
print("FusedMM time: ", fusedmm_time, 'seconds')
print()
print("Speedup: ", f'{speedup:.3}x')

Using FusedMM SpMM...
Using FusedMM SpMM...
Non-FusedMM time:  0.000351 seconds
FusedMM time:  0.000207 seconds

Speedup:  1.7x
