In [2]:
import torch



In [85]:
def fast_mm(matrix_1, matrix_2, device ="cuda:1"):
    
    torch.set_printoptions(precision=15)
    size_1 = matrix_1.size()[0]
    size_2 = matrix_1.size()[1]

    matrix_1.to(device)
    matrix_2.to(device)
    
    if size_1 * size_2 <= 10000 * 100:  #good for 10000*100 matrices
        return torch.matmul(matrix_1, matrix_2.T).reshape(size_1*size_1,1).to(device)
    
#     else: 
        
#         print("Applying batches...")
        
#         batch_size = size_1 // 100
        
#         res = torch.tensor([]).to(device)
#         for i in range(matrix_1.size()[0] // batch_size):
#             A = matrix_1[batch_size*i : batch_size *(i+1)].to(device)
#             #print(A.size())
#             for j in range (matrix_2.size()[0] // batch_size):
#                 B = matrix_2[batch_size*j : batch_size *(j+1)].to(device)
#                 #print(B.size())
#                 #mult = torch.bmm(A.view(batch_size, 1, size_2), B.view(batch_size, size_2, 1)).to(device)
#                 mult = torch.matmul(A.view(batch_size, size_2), B.view(batch_size, size_2).T).reshape(batch_size*batch_size,1) 
#                 res = torch.cat((res, mult),0)
#         return res


In [4]:
def fast_mm2(matrix_1, matrix_2, device ="cuda:1"):
    
    size_1 = matrix_1.size()[0]
    size_2 = matrix_1.size()[1]
    
    matrix_1.to(device)
    matrix_1 = matrix_1.unsqueeze(1).repeat(1,size_1,1)
        
    matrix_2.to(device)
    
    res = torch.tensor([]).to(device)
    
    for i in range (size_1):
        mult = torch.bmm(matrix_1[i].view(size_1,1,size_2), matrix_2.view(size_1,size_2,1)).to(device)
        res = torch.cat((res, mult),0)
    return res.mean(1)
        
    
   

In [5]:
def vector_idx(i,j, batch_size): #return index in a vector
    return i * batch_size + j 

In [6]:
def matrix_idx(idx,batch_size): #return 2D index in a matrix
    j = idx % batch_size 
    i = idx // batch_size
    return i,j 

In [7]:
def index_by_value(tensor, values):
    return torch.nonzero(tensor == values)[0][0].item()

### Test #1: matmul

In [None]:
matrix_1 = torch.rand(5000,100)
matrix_2 = torch.rand(5000,100)
# matrix_2 = torch.rand(100, 5000*5000)

In [None]:
device = "cuda: 1"
matrix_1.to(device)
matrix_2.to(device)


In [None]:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
result_1 = (fast_mm(matrix_1, matrix_2))
end.record()
print(result_1)
print(result_1.size())


### Test #2: bmm

In [None]:
torch.cuda.synchronize()
print(start.elapsed_time(end))
print(result_1.size())

In [None]:
start.record()
result_2 = fast_mm2(matrix_1,matrix_2)
end.record()

print(result_2)

In [None]:
torch.cuda.synchronize()
print(start.elapsed_time(end))
print(result_2.size())

In [None]:
X = torch.tensor([[2,3],[4,5],[7,8]])
Y = X.unsqueeze(1).repeat(1,3,1)
#Z = X.unsqueeze(1).repeat(1,1,1)

In [None]:
for i in range (3):
    print(torch.bmm(Y[i].view(3,1,2), X.view(3,2,1)))

In [None]:
matr = torch.tensor([[2.0,3.0],[4.0,5.0],[7.0,8.0]])
print(matr.size())
print(fast_mm2(matr,matr).size())


### Test #3: Compare matmul and bmm for small matrices

In [None]:
X = torch.tensor([[2.0,3.0],[4.0,5.0],[7.0,8.0]])
Y = torch.tensor([[1.0,2.0],[3.0,4.0],[5.0,6.0]])

In [None]:
res_bmm = fast_mm2(X,Y)
print(res_bmm)
print(res_bmm.size())

In [None]:
res_mm = fast_mm(X,Y)
print(res_mm)
print(res_mm.size())

In [None]:
print(res_mm[0].item())
print(torch.max(res_mm).item())

In [None]:
index = index_by_value(res_mm,torch.max(res_mm).item())
print(index)

In [None]:
print(matrix_idx(index,3))

In [None]:
print(vector_idx(2,2,3))

### Test #4: correctness

In [195]:
torch.set_printoptions(precision=6)
matrix_1 = torch.rand(500,100)
matrix_2 = torch.rand(500,100)
batch_size = 500
# matrix_2 = torch.rand(100, 5000*5000)

In [196]:
device = "cuda: 1"
matrix_1.to(device)
matrix_2.to(device)


tensor([[0.095547, 0.990301, 0.297882,  ..., 0.314450, 0.797820, 0.623467],
        [0.749293, 0.164841, 0.275558,  ..., 0.560175, 0.439406, 0.498890],
        [0.863934, 0.182325, 0.838508,  ..., 0.683597, 0.973021, 0.810280],
        ...,
        [0.804357, 0.209407, 0.238640,  ..., 0.766433, 0.238541, 0.050630],
        [0.587878, 0.622067, 0.606071,  ..., 0.388363, 0.916828, 0.010745],
        [0.524976, 0.311098, 0.915997,  ..., 0.454018, 0.862313, 0.824398]],
       device='cuda:1')

In [197]:
result_1 = (fast_mm(matrix_1, matrix_2)).to(device)
print(result_1.to(device))
print(result_1.size())

tensor([[26.703529357910156],
        [23.076786041259766],
        [26.795158386230469],
        ...,
        [27.729213714599609],
        [34.913894653320312],
        [26.807659149169922]], device='cuda:1')
torch.Size([250000, 1])


In [198]:
index = index_by_value(result_1,torch.min(result_1).item())
print(index)

239918


In [199]:
print(matrix_idx(index,5000))

(47, 4918)


In [200]:
print(vector_idx(198,4238,5000))

994238


In [201]:
# import numpy as np
# a = torch.matmul(matrix_1[198],matrix_2[4238]).to(device)
# a = a.item()
# print(a)
# b = result_1[994238].item()
# print(b)

# if (abs(a-b)<0.000001):
#     print(True)

In [202]:
count = 0
result_1 = (fast_mm(matrix_1, matrix_2)).to(device)
for i in range(result_1.size()[0]):
    a = matrix_idx(i,batch_size)
    i_1 = a[0]
    i_2 = a[1]
    vect_1 = torch.matmul(matrix_1[i_1],matrix_2[i_2]).to(device)
    #print(vect_1.item())
    if (abs(result_1[i].item() - vect_1.item())<0.0001):
        count = count +1
print(count)       

250000


In [203]:
result_1 = (fast_mm(matrix_1, matrix_2)).to(device)
print(result_1[555].item())

25.188858032226562


In [204]:
print(matrix_idx(555,batch_size))

(1, 55)


In [205]:
print(torch.matmul(matrix_1[1],matrix_2[55]).item())

25.188854217529297
