In [None]:
# test fc 
import torch 
from tests.datagen import save_tensor_to_csv
import os
# Two-layer fully connected net on a constant input:
#   z1 = sigmoid(W0 @ x + b0), z2 = tanh(W1 @ z1 + b1),
# followed by a softmax over dim 0 and a cross-entropy-style loss.
# The parameters and, after backward(), their gradients are written to `file`
# as reference data.
xh = 3
xw = 4
th = 2
inner = 5
torch.set_printoptions(precision=10, linewidth=2000)
torch.random.manual_seed(20)
W0 = torch.nn.Parameter(torch.randn(inner, xh))
b0 = torch.nn.Parameter(torch.randn(inner, 1))  # (inner, 1): broadcast over the xw columns
W1 = torch.nn.Parameter(torch.randn(th, inner))
b1 = torch.nn.Parameter(torch.randn(th, 1))     # (th, 1): broadcast over the xw columns

x = torch.ones(xh, xw)
y1 = torch.mm(W0, x) + b0   # (inner, xw)
z1 = torch.sigmoid(y1)
y2 = torch.mm(W1, z1) + b1  # (th, xw)
z2 = torch.tanh(y2)

file = "static_data/fc.txt"
# Start from a fresh file. The saves below presumably append one tensor after
# another (assumption about save_tensor_to_csv -- confirm). A missing file is
# expected on a fresh checkout and must not abort the cell.
try:
    os.remove(file)
except FileNotFoundError:
    pass
save_tensor_to_csv(W0, file)
save_tensor_to_csv(b0, file)
save_tensor_to_csv(W1, file)
save_tensor_to_csv(b1, file)

# Softmax over dim 0, i.e. each of the xw columns sums to 1.
exp = torch.exp(z2)
s = exp / exp.sum(dim=0, keepdim=True)

# Cross entropy. Every target entry is 5, so t is not a probability
# distribution. It is kept as-is because the saved gradients depend on it.
t = torch.ones_like(s) * 5
ce = -(t * torch.log(s)).sum()
ce.backward()
print("error: " , ce)

save_tensor_to_csv(W0.grad, file)
save_tensor_to_csv(b0.grad, file)
save_tensor_to_csv(W1.grad, file)
save_tensor_to_csv(b1.grad, file)

In [None]:
# test attention
import torch
import os
import math
from tests.datagen import save_tensor_to_csv
# Single-head attention: project q/k/v, compute softmax(q_ k_^T / sqrt(Eq)) @ v_,
# take an MSE loss against an all-ones target, and check autograd's gradients
# against a hand-derived backward pass. Inputs, output and weight gradients are
# written to `file` as reference data.
torch.manual_seed(10)
torch.set_printoptions(precision=8, linewidth=2000)

Ei = 48  # input embedding size
Eq = 16  # query embedding size
Ek = Eq  # key embedding size
Ev = Ei  # value, i.e. output embedding size
S = 5    # seq_len

# Projection weights, scaled by sqrt(1/Ei).
Q = torch.nn.Parameter(torch.randn(Eq, Ei) * math.sqrt(1/Ei))
K = torch.nn.Parameter(torch.randn(Ek, Ei) * math.sqrt(1/Ei))
V = torch.nn.Parameter(torch.randn(Ev, Ei) * math.sqrt(1/Ei))

q = torch.rand(S, Ei)
k = torch.rand(S, Ei)
v = torch.rand(S, Ei)

q_ = q @ Q.t()  # (S, Eq)
k_ = k @ K.t()  # (S, Ek)
v_ = v @ V.t()  # (S, Ev)

# Non-leaf tensors: keep their .grad so it can be compared with the manual values.
q_.retain_grad()
k_.retain_grad()
v_.retain_grad()

print("Q.shape: ", Q.shape, "K.shape: ", K.shape, "V.shape: ", V.shape)
print("q_.shape: ", q_.shape, "k_.shape: ", k_.shape, "v_.shape: ", v_.shape)

qkt = q_ @ k_.t() / (Eq ** .5)  # (S, S) scaled scores
print("qkt.shape: ", qkt.shape)
qkt.retain_grad()
# NOTE(review): this normalises over dim=0 (each column sums to 1). Standard
# scaled dot-product attention normalises over the key axis (dim=-1). It is left
# unchanged because the saved reference data presumably has to match the
# implementation under test -- confirm.
smax = qkt.exp() / qkt.exp().sum(dim=0, keepdim=True)
smax.retain_grad()
output = smax @ v_  # (S, Ev)
output.retain_grad()
print("output shape: ", output.shape)

t = torch.ones_like(output)
loss = (t - output).pow(2).mean()
loss.backward()

# ---- Manual backward pass ----
# d(mean((t - output)^2)) / d(output), averaged over S * Ev elements.
l2_grad = 2 * (output - t) / (S * Ev)
# output = smax @ v_  =>  dL/dsmax = dL/doutput @ v_^T
smax_grad_in = torch.mm(l2_grad, v_.t())
# Backward of a softmax normalised over dim 0: s * (g - sum_0(g * s)).
qkt_grad_in = smax * (smax_grad_in - (smax_grad_in * smax).sum(dim=0, keepdim=True))

v_grad_in = torch.mm(smax.t(), l2_grad)
q_grad_in = torch.mm(qkt.grad, k_) / (Eq ** .5)
k_grad_in = torch.mm(q_.t(), qkt.grad).t() / (Eq ** .5)

Q_grad_in = torch.mm(q.t(), q_grad_in).t()
K_grad_in = torch.mm(k.t(), k_grad_in).t()
V_grad_in = torch.mm(v.t(), v_grad_in).t()

assert torch.allclose(output.grad, l2_grad), "output grad mismatch"
assert torch.allclose(smax.grad, smax_grad_in), "smax grad mismatch"
# The softmax backward subtracts nearly equal terms, so allow a small absolute slack.
assert torch.allclose(qkt.grad, qkt_grad_in, atol=1e-7), "qkt grad mismatch"
assert torch.allclose(v_.grad, v_grad_in), "v_ grad mismatch"
assert torch.allclose(q_.grad, q_grad_in), "q_ grad mismatch"
assert torch.allclose(k_.grad, k_grad_in), "k_ grad mismatch"
assert torch.allclose(Q.grad, Q_grad_in), "Q grad mismatch"
assert torch.allclose(K.grad, K_grad_in), "K grad mismatch"
assert torch.allclose(V.grad, V_grad_in), "V grad mismatch"

file = "static_data/attention.txt"
# Start from a fresh file; a missing file is expected on a fresh checkout.
try:
    os.remove(file)
except FileNotFoundError:
    pass
save_tensor_to_csv(Q, file)
save_tensor_to_csv(K, file)
save_tensor_to_csv(V, file)
save_tensor_to_csv(q, file)
save_tensor_to_csv(k, file)
save_tensor_to_csv(v, file)
save_tensor_to_csv(output, file)
save_tensor_to_csv(Q.grad, file)
save_tensor_to_csv(K.grad, file)
save_tensor_to_csv(V.grad, file)

print("output.grad:\n ", output.grad)
print("smax.grad:\n ", smax.grad)
print("qkt.grad:\n ", qkt.grad)
print("q_.grad:\n ", q_.grad)
print("k_.grad:\n ", k_.grad)
print("v_.grad:\n ", v_.grad)

In [None]:
# test linear
import torch 
from tests.datagen import save_tensor_to_csv
import os
# Two bias-free linear layers: h0 = x @ W0^T (s, l0), h1 = h0 @ W1^T (s, l1),
# with an MSE loss against an all-ones target. Weights, input and weight
# gradients are written to `file` as reference data.
s = 3   # rows of x
e = 4   # input features
l0 = 5  # output features of the first layer
l1 = 6  # output features of the second layer

torch.set_printoptions(precision=10, linewidth=2000)
torch.random.manual_seed(20)
W0 = torch.nn.Parameter(torch.randn(l0, e))
W1 = torch.nn.Parameter(torch.randn(l1, l0))
x = torch.randn(s, e)

# Activations get their own names instead of overwriting the layer sizes l0/l1.
h0 = torch.mm(x, W0.t())
h1 = torch.mm(h0, W1.t())

file = "static_data/linear.txt"
# Start from a fresh file; a missing file is expected on a fresh checkout.
try:
    os.remove(file)
except FileNotFoundError:
    pass

print(h1.shape)
print("W0.shape: ", W0.shape, "W1.shape: ", W1.shape, "x.shape: ", x.shape)

loss = (h1 - torch.ones_like(h1)).pow(2).mean()
loss.backward()

print("Loss: ", loss, "\nW0.grad: ", W0.grad, "\nW1.grad: ", W1.grad)

save_tensor_to_csv(W0, file)
save_tensor_to_csv(W1, file)
save_tensor_to_csv(x, file)
save_tensor_to_csv(W0.grad, file)
save_tensor_to_csv(W1.grad, file)