Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x][LT] Add forward & backward linalg.gemm test for large size #18825

Merged
merged 2 commits into from
Jul 30, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 34 additions & 7 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,23 +1211,49 @@ def check_syrk_batch():
assert_almost_equal(A.grad[1,0,0], nd.array([0.4]), rtol=1e-3, atol=1e-5)

def check_gemm2():
def run_gemm2(inp1,inp2):
def run_gemm2(inp1, inp2):
inp1.attach_grad()
inp2.attach_grad()
with mx.autograd.record():
out = mx.nd.linalg.gemm2(inp1,inp2)
out = mx.nd.linalg.gemm2(inp1, inp2)
return inp1.grad, inp2.grad, out

inp1=mx.nd.ones(shape=(SMALL_Y, LARGE_X))
inp1[0][0]=0.1
inp2=mx.nd.ones(shape=(LARGE_X, SMALL_Y))
inp1_grad, inp2_grad, out= run_gemm2(inp1,inp2)
inp1 = mx.nd.ones(shape=(SMALL_Y, LARGE_X))
perturbation = 0.2
inp1[0][0] = perturbation
inp2 = mx.nd.ones(shape=(LARGE_X, SMALL_Y))
inp1_grad, inp2_grad, out = run_gemm2(inp1, inp2)
assert out.asnumpy()[0][0] == LARGE_X
assert out.shape == (SMALL_Y, SMALL_Y)
out.backward()
assert inp1_grad.shape == (SMALL_Y, LARGE_X)
assert inp2_grad.shape == (LARGE_X, SMALL_Y)
assert_almost_equal(inp2_grad.asnumpy()[0][0],49.1)
assert_almost_equal(inp1_grad.asnumpy()[0][0], SMALL_Y)
assert_almost_equal(inp2_grad.asnumpy()[0][0], SMALL_Y-(1-perturbation))
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved

def check_gemm():
def run_gemm(inp1,inp2, inp3):
inp1.attach_grad()
inp2.attach_grad()
inp3.attach_grad()
with mx.autograd.record():
out = mx.nd.linalg.gemm(inp1, inp2, inp3, transpose_b=True)
return inp1.grad, inp2.grad, inp3.grad, out

inp1 = mx.nd.ones(shape=(MEDIUM_X, SMALL_Y, MEDIUM_X))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spaces around =

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's incorrect according to PEP8 style guide Python
https://www.python.org/dev/peps/pep-0008/
Screen Shot 2020-07-30 at 1 05 56 AM

perturbation = 0.2
inp1[0][0][0] = perturbation
inp2 = mx.nd.ones(shape=(MEDIUM_X, SMALL_Y, MEDIUM_X))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spaces around =

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's incorrect according to PEP8 style guide Python
https://www.python.org/dev/peps/pep-0008/
Screen Shot 2020-07-30 at 1 05 56 AM

inp3 = mx.nd.ones(shape=(MEDIUM_X, SMALL_Y, SMALL_Y))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's incorrect according to PEP8 style guide Python
https://www.python.org/dev/peps/pep-0008/
Screen Shot 2020-07-30 at 1 05 56 AM

inp1_grad, inp2_grad, inp3_grad, out= run_gemm(inp1, inp2, inp3)
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved
assert_almost_equal(out.asnumpy()[0][0][0], MEDIUM_X+perturbation)
assert out.shape == inp3.shape
out.backward()
assert inp1_grad.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X)
assert inp2_grad.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X)
assert inp3_grad.shape == (MEDIUM_X, SMALL_Y, SMALL_Y)
assert_almost_equal(inp1_grad.asnumpy()[0][0][0], SMALL_Y)
assert_almost_equal(inp2_grad.asnumpy()[0][0][0], SMALL_Y-(1-perturbation))
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved

def check_det():
def run_det(inp):
Expand Down Expand Up @@ -1340,6 +1366,7 @@ def run_trsm(inp):
assert(grad[0, 0, 0] == 0)
assert(grad[1, 0, 0] == 0)

check_gemm()
check_potrf()
check_potri()
check_syrk_batch()
Expand Down