In [6]:
import torch
import torch.nn as nn
from torch.nn import functional as F


In [101]:
# Toy batch: 10 samples with 5 features each, plus one integer class label in {0, 1, 2} per sample.
# NOTE(review): no torch.manual_seed(...) is set before these draws, so the values (and every
# output recorded for this section) change on each fresh run.
x = torch.randn((10,5), requires_grad=True)  # requires_grad so autograd also fills x.grad
y = torch.randint(0,3, (10,))  # upper bound 3 is exclusive
y

tensor([2, 1, 0, 1, 2, 1, 2, 0, 0, 0])

In [102]:
# Two stacked linear layers, 5 -> 10 -> 3: `out` holds 3 raw scores (logits) per sample.
li = nn.Linear(in_features=5, out_features=10)
li2 = nn.Linear(10,3)

li_out = li(x)
out = li2(li_out)
# Non-leaf tensors drop .grad after backward unless retain_grad() is called; keeping it lets the
# hand-computed gradients below be compared with autograd's. (x is a leaf created with
# requires_grad=True, so its call is redundant but harmless.)
out.retain_grad()
x.retain_grad()
li_out.retain_grad()

In [103]:
li2.weight.grad

In [104]:
loss = F.cross_entropy(out, y)
loss

tensor(1.1918, grad_fn=<NllLossBackward0>)

In [105]:
loss.backward()

In [106]:
# NOTE(review): `out` is the raw output of li2 (logits), not probabilities, so torch.log returns
# nan wherever the selected logit is negative -- hence the nan entries in the output below.
# A per-sample cross-entropy needs the log-softmax of `out`, not the log of `out` itself.
-torch.log(out[range(10), y])  

tensor([0.5113, 1.1901, 1.0210,    nan, 0.6620, 2.4949, 0.3594, 1.0085, 0.6821,
           nan], grad_fn=<NegBackward0>)

In [107]:
class CrossEntropyLoss:
    """Mean cross-entropy over a batch of raw logits, with a hand-derived gradient.

    ``__call__`` evaluates the loss; ``backward`` returns the analytic gradient of
    that loss with respect to the logits, ``(softmax(logits) - one_hot(y_true)) / n``.
    Both expect ``y_pred`` to be 2-D with samples along dim 0 and classes along dim 1.
    """

    def __call__(self,
                 y_pred: torch.Tensor,
                 y_true: torch.Tensor
                 ):
        """Return the mean negative log-likelihood of ``y_true`` under ``softmax(y_pred)``.

        Args:
            y_pred: logits, one row per sample and one column per class.
            y_true: integer class index for each row of ``y_pred``.

        Returns:
            Scalar loss tensor; it is also stored on ``self.out``.
        """
        n_samples = y_pred.shape[0]
        # Shift each row so its maximum is 0; exp() can then never overflow.
        shifted = y_pred - y_pred.max(1, keepdim=True).values
        # log-softmax computed directly as z - log(sum(exp(z))). Forming the
        # probabilities first and then taking their log yields -inf whenever a
        # probability underflows to 0, which makes the whole loss infinite.
        logprobs = shifted - shifted.exp().sum(1, keepdim=True).log()
        self.out = -logprobs[range(n_samples), y_true].mean()
        return self.out

    def backward(self,
                 y_pred: torch.Tensor,
                 y_true: torch.Tensor
                 ):
        """Return d(loss)/d(y_pred) for the mean cross-entropy.

        Args:
            y_pred: the logits that were passed to ``__call__``.
            y_true: integer class index for each row of ``y_pred``.

        Returns:
            Tensor with the same shape as ``y_pred``.
        """
        n_samples = y_pred.shape[0]
        # Work on a copy: autograd saves the softmax output for its own backward
        # pass, and writing into that tensor in place invalidates the graph.
        grad = F.softmax(y_pred, dim=1).clone()
        grad[range(n_samples), y_true] -= 1
        grad = grad / n_samples
        return grad

    def parameters(self):
        """Return the trainable tensors of this module (the loss has none)."""
        return []

    def paramerters(self):
        """Misspelled alias of ``parameters``, kept so existing callers keep working."""
        return self.parameters()
    
lo = CrossEntropyLoss()
lo(out, y)

tensor(1.1918, grad_fn=<NegBackward0>)

In [108]:

dloss = lo.backward(out, y)

print(f"shape of dloss{dloss.shape}, out {out.shape}")
print(dloss)

shape of dlosstorch.Size([10, 3]), out torch.Size([10, 3])
tensor([[ 0.0240,  0.0347, -0.0587],
        [ 0.0299, -0.0671,  0.0373],
        [-0.0646,  0.0288,  0.0357],
        [ 0.0339, -0.0801,  0.0462],
        [ 0.0240,  0.0305, -0.0545],
        [ 0.0206, -0.0763,  0.0557],
        [ 0.0295,  0.0259, -0.0554],
        [-0.0686,  0.0274,  0.0413],
        [-0.0670,  0.0303,  0.0366],
        [-0.0860,  0.0463,  0.0397]], grad_fn=<DivBackward0>)


In [109]:
li2.weight.grad

tensor([[ 0.1641, -0.0714, -0.0778,  0.0775,  0.0259,  0.0220,  0.0946, -0.0208,
          0.0993,  0.0915],
        [-0.1298,  0.0790,  0.0594, -0.1384, -0.0275, -0.0181, -0.0883,  0.0483,
         -0.1244, -0.0467],
        [-0.0343, -0.0076,  0.0184,  0.0609,  0.0015, -0.0039, -0.0063, -0.0275,
          0.0250, -0.0448]])

In [110]:
print(li_out.shape, dloss.shape)
# Hand-computed gradient of the loss w.r.t. li2.weight:
# dloss.T is (3, 10) and li_out is (10, 10), so dw is (3, 10) like li2.weight.grad above.
dw = dloss.T @ li_out
dw

torch.Size([10, 10]) torch.Size([10, 3])


tensor([[ 0.1641, -0.0714, -0.0778,  0.0775,  0.0259,  0.0220,  0.0946, -0.0208,
          0.0993,  0.0915],
        [-0.1298,  0.0790,  0.0594, -0.1384, -0.0275, -0.0181, -0.0883,  0.0483,
         -0.1244, -0.0467],
        [-0.0343, -0.0076,  0.0184,  0.0609,  0.0015, -0.0039, -0.0063, -0.0275,
          0.0250, -0.0448]], grad_fn=<MmBackward0>)

In [111]:
torch.allclose(dw, li2.weight.grad)

True

In [112]:
# Bias gradient of li2: the bias is added to every row, so its gradient is the
# logit gradient summed over the batch dimension.
db = dloss.sum(dim = 0)
li2.bias.grad, db

(tensor([-0.1244,  0.0004,  0.1240]),
 tensor([-0.1244,  0.0004,  0.1240], grad_fn=<SumBackward1>))

In [113]:

# Gradient flowing back into li_out (the input of li2): (10, 3) @ (3, 10) -> (10, 10).
dli_out =  dloss @ li2.weight
dli_out

tensor([[-3.0885e-03, -4.2290e-03, -1.7008e-02, -1.9706e-02,  6.0901e-03,
         -6.6475e-03, -1.6097e-02, -1.0117e-03, -4.1209e-03,  4.6926e-03],
        [ 1.2462e-03,  1.3961e-02, -1.0619e-02,  2.8985e-02, -1.1432e-02,
          1.2178e-02,  2.0143e-02,  3.1030e-03,  4.2178e-03,  5.3580e-03],
        [ 2.6719e-03, -9.9180e-03,  3.4079e-02, -6.2542e-03,  4.6740e-03,
         -4.7698e-03, -1.1945e-03, -2.1102e-03,  7.3434e-04, -1.2093e-02],
        [ 1.5972e-03,  1.6520e-02, -1.1650e-02,  3.4791e-02, -1.3646e-02,
          1.4544e-02,  2.4287e-02,  3.6749e-03,  5.1194e-03,  6.0543e-03],
        [-2.8978e-03, -3.4967e-03, -1.6629e-02, -1.7680e-02,  5.3689e-03,
         -5.8718e-03, -1.4578e-02, -8.4562e-04, -3.7682e-03,  4.6815e-03],
        [ 2.2472e-03,  1.4847e-02, -4.4143e-03,  3.4539e-02, -1.3051e-02,
          1.3957e-02,  2.4820e-02,  3.3243e-03,  5.4523e-03,  3.5500e-03],
        [-3.0229e-03, -2.2650e-03, -1.9320e-02, -1.6071e-02,  4.5882e-03,
         -5.0540e-03, -1.3668e-0

In [115]:
torch.allclose(li_out.grad, dli_out)

True

In [116]:
out.grad

tensor([[ 0.0240,  0.0347, -0.0587],
        [ 0.0299, -0.0671,  0.0373],
        [-0.0646,  0.0288,  0.0357],
        [ 0.0339, -0.0801,  0.0462],
        [ 0.0240,  0.0305, -0.0545],
        [ 0.0206, -0.0763,  0.0557],
        [ 0.0295,  0.0259, -0.0554],
        [-0.0686,  0.0274,  0.0413],
        [-0.0670,  0.0303,  0.0366],
        [-0.0860,  0.0463,  0.0397]])

In [117]:
# Weight gradient of the first layer: dli_out.T is (10, 10) and x is (10, 5),
# giving (10, 5) -- the same shape as li.weight.
dw1 = dli_out.T @ x
# Only the last expression of a cell is displayed, so the bare tuple that used to
# sit here had no effect and has been dropped.
torch.allclose(dw1, li.weight.grad)

True

In [118]:
# Bias gradient of the first layer: gradient at li_out summed over the batch dimension.
db1 = dli_out.sum(0)
torch.allclose(db1, li.bias.grad)

True

In [119]:
print(dli_out.shape, li.weight.shape)
# Gradient w.r.t. the network input: (10, 10) @ (10, 5) -> (10, 5), the shape of x.
dx = dli_out @ li.weight

torch.Size([10, 10]) torch.Size([10, 5])


In [121]:
x.grad, dx

(tensor([[-3.0510e-05, -3.7117e-05,  1.1785e-02,  7.2215e-03, -2.1346e-03],
         [-4.9092e-03, -1.2360e-02, -1.1880e-02,  1.4466e-04,  3.0314e-03],
         [ 5.4787e-03,  1.3741e-02, -2.3019e-03, -9.6354e-03, -5.5749e-04],
         [-5.7404e-03, -1.4454e-02, -1.4428e-02, -1.5749e-04,  3.6420e-03],
         [-2.1783e-04, -5.1056e-04,  1.0783e-02,  6.8930e-03, -1.9194e-03],
         [-4.7046e-03, -1.1858e-02, -1.5418e-02, -2.3170e-03,  3.6374e-03],
         [-7.8253e-04, -1.9288e-03,  1.0445e-02,  7.5358e-03, -1.7574e-03],
         [ 5.6815e-03,  1.4247e-02, -3.2364e-03, -1.0509e-02, -4.2391e-04],
         [ 5.7016e-03,  1.4301e-02, -2.2830e-03, -9.9589e-03, -6.0061e-04],
         [ 7.6506e-03,  1.9195e-02, -1.1681e-03, -1.2209e-02, -1.1501e-03]]),
 tensor([[-3.0510e-05, -3.7116e-05,  1.1785e-02,  7.2215e-03, -2.1346e-03],
         [-4.9092e-03, -1.2360e-02, -1.1880e-02,  1.4466e-04,  3.0314e-03],
         [ 5.4787e-03,  1.3741e-02, -2.3019e-03, -9.6354e-03, -5.5749e-04],
         [

In [None]:
class Linear:
    """Minimal affine layer ``X @ weight (+ bias)`` with a hand-written backward pass.

    ``weight`` is stored as ``(fan_in, fan_out)`` and drawn from a standard normal
    scaled by ``1 / sqrt(fan_in)``; ``bias`` is a standard-normal vector of length
    ``fan_out``, or ``None`` when disabled.
    """

    def __init__(self,
                 fan_in: int,
                 fan_out: int,
                 bias=True):
        scale = fan_in ** 0.5
        self.weight = torch.randn((fan_in, fan_out)) / scale
        if bias:
            self.bias = torch.randn(fan_out)
        else:
            self.bias = None

    def __call__(self,
                 X: torch.Tensor):
        """Apply the layer to ``X`` and remember ``X`` for the backward pass."""
        self.last_input = X
        result = torch.matmul(X, self.weight)
        if self.bias is not None:
            result += self.bias
        self.out = result
        return result

    def backward(self, d_L_d_out):
        """Turn the gradient at the layer output into input, weight and bias gradients.

        Uses the input cached by the most recent ``__call__``.

        Returns:
            Tuple ``(d_L_d_input, d_L_d_weights, d_L_d_biases)``.
        """
        grad_input = torch.matmul(d_L_d_out, self.weight.T)
        grad_weights = torch.matmul(self.last_input.T, d_L_d_out)
        # The bias is broadcast over rows, so its gradient is the sum over dim 0.
        grad_biases = d_L_d_out.sum(dim=0)
        return grad_input, grad_weights, grad_biases

    def parameters(self):
        """List the layer's tensors: the weight, followed by the bias if there is one."""
        params = [self.weight]
        if self.bias is not None:
            params.append(self.bias)
        return params

In [37]:
# Cross-entropy broken into elementary operations so every intermediate can be inspected.
logit_maxes = out.max(1, keepdim=True).values  # per-row maximum, kept as a column
norm_logits = out - logit_maxes  # subtract the row max for numerical stability (see the previous notebooks)
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims = True)
counts_sum_inv = counts_sum ** -1  # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv  # each row divided by its own sum, i.e. the softmax
logprobs = probs.log()
# Mean negative log-probability of the labelled class; range(10) hard-codes the batch size.
loss = -logprobs[range(10), y].mean()

In [38]:
loss

tensor(1.1548, grad_fn=<NegBackward0>)

In [39]:
logit_maxes, norm_logits

(tensor([[0.2092],
         [0.5867],
         [0.1870],
         [0.0802],
         [0.2657],
         [0.2117],
         [0.1744],
         [0.3677],
         [0.2169],
         [0.2833]], grad_fn=<MaxBackward0>),
 tensor([[-0.5269,  0.0000, -0.3183],
         [ 0.0000, -1.0006, -0.8468],
         [-0.0870, -0.0782,  0.0000],
         [-0.1711,  0.0000, -0.2601],
         [-0.0527,  0.0000, -0.5498],
         [-0.1575, -0.1237,  0.0000],
         [ 0.0000, -0.2957, -0.2016],
         [ 0.0000, -0.2819, -0.6581],
         [-0.0386, -0.0865,  0.0000],
         [ 0.0000, -0.1860, -0.6365]], grad_fn=<SubBackward0>))

In [40]:
counts

tensor([[0.5904, 1.0000, 0.7274],
        [1.0000, 0.3677, 0.4288],
        [0.9166, 0.9248, 1.0000],
        [0.8427, 1.0000, 0.7709],
        [0.9487, 1.0000, 0.5770],
        [0.8543, 0.8836, 1.0000],
        [1.0000, 0.7440, 0.8174],
        [1.0000, 0.7543, 0.5178],
        [0.9622, 0.9171, 1.0000],
        [1.0000, 0.8302, 0.5291]], grad_fn=<ExpBackward0>)

## With Non-Linearity

In [123]:
torch.manual_seed(42)
x = torch.randn((10, 5), requires_grad=True)
y = torch.randint(0,3,(10,))

print(x.shape, y.shape)

torch.Size([10, 5]) torch.Size([10])


In [125]:
# Same 5 -> 10 -> 3 network as before, now with a GELU between the two linear layers.
li = nn.Linear(5, 10)
li2 = nn.Linear(10,3)
re = nn.GELU()  # constructed with default arguments

li_out = li(x)
re_out = re(li_out)
li2_out = li2(re_out)

# Keep .grad on every intermediate activation so the manual gradients can be checked
# against autograd after loss.backward().
for op in [li_out, re_out, li2_out]:
    op.retain_grad()
print(li2_out)

tensor([[-0.3346,  0.3682, -0.4057],
        [ 0.0695,  0.0743, -0.3219],
        [ 0.1473, -0.1134, -0.1617],
        [-0.0190,  0.3161, -0.2893],
        [ 0.3820, -0.0245, -0.5353],
        [ 0.0672,  0.3544, -0.4217],
        [ 0.0575,  0.1941, -0.1656],
        [ 0.0911,  0.2511, -0.3718],
        [ 0.2456, -0.1051, -0.4400],
        [ 0.1275,  0.1981, -0.3591]], grad_fn=<AddmmBackward0>)


In [127]:
loss = F.cross_entropy(li2_out, y)
loss

tensor(1.1506, grad_fn=<NllLossBackward0>)

In [128]:
loss.backward()

In [137]:
def manual_loss(y_pred, y_true):
    """Mean cross-entropy of integer labels ``y_true`` under the logits ``y_pred``.

    Args:
        y_pred: 2-D tensor of logits, one row per sample and one column per class.
        y_true: integer class index for each row of ``y_pred``.

    Returns:
        Scalar tensor holding the mean negative log-probability of the true classes.
    """
    n_samples = y_pred.shape[0]
    # Shift each row so its maximum is 0; exp() can then never overflow.
    shifted = y_pred - y_pred.max(1, keepdim=True).values
    # log-softmax as z - log(sum(exp(z))): exp() is evaluated once, and no
    # probability is formed that could underflow to 0 and turn into log(0) = -inf.
    logprobs = shifted - shifted.exp().sum(1, keepdim=True).log()
    loss = -logprobs[range(0, n_samples), y_true].mean()
    return loss

def lbackward(y_pred, y_true):
    """Analytic gradient of the mean cross-entropy loss with respect to the logits.

    Computes ``(softmax(y_pred) - one_hot(y_true)) / n_samples``.

    Args:
        y_pred: 2-D tensor of logits, one row per sample and one column per class.
        y_true: integer class index for each row of ``y_pred``.

    Returns:
        Tensor with the same shape as ``y_pred``.
    """
    n_samples = y_pred.shape[0]
    # Work on a copy: autograd saves the softmax output for its own backward pass,
    # and writing into that tensor in place invalidates the graph.
    grad = F.softmax(y_pred, dim = 1).clone()
    grad[range(0, n_samples), y_true] -= 1
    grad = grad / n_samples
    return grad
    
    

man_l = manual_loss(li2_out, y)
dloss = lbackward(li2_out, y)

print(man_l, dloss)    

tensor(1.1506, grad_fn=<NegBackward0>) tensor([[-0.0747,  0.0511,  0.0236],
        [ 0.0373, -0.0625,  0.0252],
        [ 0.0399, -0.0692,  0.0293],
        [ 0.0316, -0.0558,  0.0241],
        [ 0.0484,  0.0322, -0.0807],
        [-0.0661,  0.0452,  0.0208],
        [ 0.0339, -0.0611,  0.0272],
        [ 0.0357, -0.0581,  0.0225],
        [ 0.0453, -0.0681,  0.0228],
        [ 0.0372,  0.0399, -0.0771]], grad_fn=<DivBackward0>)


In [139]:
li2_out.grad

tensor([[-0.0747,  0.0511,  0.0236],
        [ 0.0373, -0.0625,  0.0252],
        [ 0.0399, -0.0692,  0.0293],
        [ 0.0316, -0.0558,  0.0241],
        [ 0.0484,  0.0322, -0.0807],
        [-0.0661,  0.0452,  0.0208],
        [ 0.0339, -0.0611,  0.0272],
        [ 0.0357, -0.0581,  0.0225],
        [ 0.0453, -0.0681,  0.0228],
        [ 0.0372,  0.0399, -0.0771]])

In [141]:
def softmax(x, dim:int):
    """Numerically stable softmax of ``x`` along dimension ``dim``.

    Args:
        x: input tensor of scores.
        dim: dimension along which the outputs sum to 1.

    Returns:
        Tensor of the same shape as ``x``.
    """
    # Subtracting the maximum along `dim` leaves the result unchanged but keeps
    # exp() from overflowing.
    shifted = x - x.max(dim, keepdim=True).values
    # Exponentiate once and reuse the result for numerator and denominator.
    exps = shifted.exp()
    return exps / exps.sum(dim = dim, keepdim =True)

softmax(li2_out, 1)

tensor([[0.2531, 0.5112, 0.2357],
        [0.3730, 0.3748, 0.2522],
        [0.3992, 0.3076, 0.2931],
        [0.3163, 0.4423, 0.2414],
        [0.4841, 0.3224, 0.1934],
        [0.3394, 0.4524, 0.2082],
        [0.3394, 0.3891, 0.2715],
        [0.3568, 0.4187, 0.2246],
        [0.4529, 0.3189, 0.2282],
        [0.3720, 0.3993, 0.2287]], grad_fn=<DivBackward0>)

In [142]:
F.softmax(li2_out, dim=1)

tensor([[0.2531, 0.5112, 0.2357],
        [0.3730, 0.3748, 0.2522],
        [0.3992, 0.3076, 0.2931],
        [0.3163, 0.4423, 0.2414],
        [0.4841, 0.3224, 0.1934],
        [0.3394, 0.4524, 0.2082],
        [0.3394, 0.3891, 0.2715],
        [0.3568, 0.4187, 0.2246],
        [0.4529, 0.3189, 0.2282],
        [0.3720, 0.3993, 0.2287]], grad_fn=<SoftmaxBackward0>)

In [144]:
# Hand-computed gradient w.r.t. li2.weight: li2 now receives re_out (the GELU output),
# so the product is dloss.T @ re_out, giving (3, 10).
dw2 = dloss.T @ re_out
dw2

tensor([[ 0.0118,  0.0408,  0.0765, -0.0865,  0.0398, -0.0183,  0.0924,  0.1141,
         -0.1180, -0.0372],
        [-0.0151, -0.0544, -0.0540,  0.0724,  0.0017,  0.0256, -0.0657, -0.1810,
          0.0606,  0.0311],
        [ 0.0033,  0.0136, -0.0224,  0.0141, -0.0416, -0.0072, -0.0267,  0.0669,
          0.0574,  0.0061]], grad_fn=<MmBackward0>)

In [146]:
li2.weight.grad

tensor([[ 0.0118,  0.0408,  0.0765, -0.0865,  0.0398, -0.0183,  0.0924,  0.1141,
         -0.1180, -0.0372],
        [-0.0151, -0.0544, -0.0540,  0.0724,  0.0017,  0.0256, -0.0657, -0.1810,
          0.0606,  0.0311],
        [ 0.0033,  0.0136, -0.0224,  0.0141, -0.0416, -0.0072, -0.0267,  0.0669,
          0.0574,  0.0061]])

In [151]:
# Bias gradient of li2: logit gradient summed over the batch dimension.
db2 = dloss.sum(0)
db2

tensor([ 0.1686, -0.2063,  0.0377], grad_fn=<SumBackward1>)

In [150]:
li2.bias.grad

tensor([ 0.1686, -0.2063,  0.0377])

In [156]:
# Gradient flowing back into re_out (the GELU output that feeds li2), checked against autograd.
dre_out = dloss @ li2.weight
torch.allclose(dre_out, re_out.grad)

True

In [160]:
# NOTE(review): this is not the weight gradient of `li`. dre_out is the gradient w.r.t. the
# GELU *output*; it must first be multiplied elementwise by the GELU derivative evaluated at
# li_out to get the gradient w.r.t. li_out, and that result (transposed) is then multiplied
# by x, not by li_out. As written the product is (10, 10), whereas li.weight.grad is (10, 5).
dw1 = dre_out.T @ li_out
dw1.shape

torch.Size([10, 10])

In [162]:
li_out.shape

torch.Size([10, 10])

In [165]:
li_out.grad

tensor([[-1.5644e-03, -3.0173e-03, -6.0959e-03, -1.6417e-02, -2.8712e-03,
         -1.8655e-03,  1.2507e-03, -4.9156e-03,  2.0204e-02,  1.7993e-02],
        [ 4.0592e-03, -2.2008e-02,  7.8039e-03, -4.7219e-04,  1.4989e-02,
          2.5203e-03,  7.6645e-03,  1.8097e-02, -2.6238e-03, -1.1918e-04],
        [ 9.1209e-03, -9.5669e-03,  9.3411e-03,  6.0112e-04,  4.2643e-03,
          1.7493e-03,  9.0313e-03,  2.1762e-02, -1.1307e-02, -3.0412e-05],
        [-6.3458e-04, -1.2319e-03,  5.5007e-03,  5.2362e-03,  5.8433e-03,
          5.1021e-04,  1.7442e-03,  8.5731e-03, -1.9086e-02,  4.3253e-04],
        [-2.1148e-02, -5.0481e-05,  2.6310e-02,  2.7916e-03,  1.8492e-02,
          2.5541e-03,  1.8069e-02,  1.8258e-03,  1.3345e-03, -2.1255e-02],
        [ 1.5732e-04, -2.4007e-03, -1.6478e-02, -1.2023e-02, -7.6808e-03,
         -1.6502e-03, -9.3241e-03, -2.1535e-03,  1.6880e-02,  8.2291e-03],
        [ 1.6962e-03, -9.5929e-03,  6.6492e-03,  2.3927e-03,  5.4254e-03,
          8.7450e-04,  2.9921e-0

In [163]:
li.weight.grad

tensor([[-0.0016, -0.0471, -0.0307, -0.0382, -0.0311],
        [ 0.0360,  0.0106,  0.0528,  0.0267, -0.0328],
        [-0.0655, -0.0119, -0.0209,  0.0398,  0.0311],
        [-0.0345, -0.0274, -0.0199,  0.0279, -0.0071],
        [-0.0468,  0.0008, -0.0265,  0.0191,  0.0481],
        [-0.0068, -0.0082, -0.0030,  0.0016,  0.0062],
        [-0.0172, -0.0224,  0.0040,  0.0031,  0.0364],
        [-0.0846, -0.0041, -0.0890, -0.0092, -0.0094],
        [ 0.0709,  0.0113,  0.0724, -0.0117,  0.0346],
        [ 0.0581,  0.0125,  0.0067, -0.0736, -0.0160]])

In [164]:
li.bias.grad

tensor([ 0.0031, -0.0603,  0.0610, -0.0078,  0.0565,  0.0111,  0.0622,  0.0689,
        -0.0233,  0.0037])

In [166]:
from typing import Any
import math

class GELU:
    """Gaussian Error Linear Unit, ``x * Phi(x)``.

    Args:
        approximate: ``"tanh"`` (default) evaluates the tanh approximation
            ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``;
            ``"none"`` evaluates the exact form ``0.5 * x * (1 + erf(x / sqrt(2)))``.

    Raises:
        ValueError: if ``approximate`` is neither ``"tanh"`` nor ``"none"``.
    """

    def __init__(self, approximate: str = "tanh") -> None:
        if approximate not in ("tanh", "none"):
            raise ValueError(
                f"approximate must be 'tanh' or 'none', got {approximate!r}"
            )
        self.approximate = approximate

    def __call__(self, x:torch.Tensor) -> torch.Tensor:
        """Apply GELU elementwise; the result is also stored on ``self.out``."""
        if self.approximate == "tanh":
            inner = math.sqrt(2/math.pi) * (x + 0.044715 * torch.pow(x,3))
            self.out = 0.5 * x * (1 + torch.tanh(inner))
        else:
            # Exact form: Phi is the standard normal CDF, written with erf.
            self.out = 0.5 * x * (1 + torch.erf(x / math.sqrt(2.0)))
        return self.out
    
gel = GELU()
gel_out = gel(li_out)

gel_out        

tensor([[-0.0871, -0.1119, -0.1381,  0.6630, -0.1627, -0.1595, -0.1665,  0.2566,
          1.6397,  0.4571],
        [-0.0831,  0.8619,  0.0823, -0.1654,  0.3347, -0.0838,  0.2873,  0.7110,
         -0.1560,  0.0233],
        [ 0.1213, -0.0403,  0.1530, -0.1654, -0.1285, -0.1293,  0.4315,  1.1305,
         -0.0093, -0.1575],
        [-0.1657, -0.1653,  0.0147,  0.2009, -0.0602, -0.1633, -0.1286,  0.0303,
          0.6222,  0.3229],
        [ 0.0151,  0.1042,  0.7081, -0.1081,  0.4864, -0.1419,  0.3370, -0.1576,
         -0.1525,  0.1868],
        [-0.1698, -0.1404,  0.0699,  0.3278, -0.1157, -0.1595, -0.0402, -0.0595,
          0.7224, -0.0090],
        [-0.1542, -0.0031,  0.0649, -0.0948, -0.0856, -0.1539, -0.0726,  0.6602,
          0.2684,  0.0561],
        [-0.1235, -0.1461,  0.2469,  0.0871, -0.1694, -0.1499, -0.0322,  0.4044,
          0.7016, -0.1574],
        [ 0.2839,  0.0905,  0.4717, -0.1684,  0.0582, -0.1258,  0.6478,  0.2361,
         -0.1516, -0.1700],
        [-0.1700, -

In [167]:
re_out

tensor([[-0.0869, -0.1117, -0.1380,  0.6631, -0.1626, -0.1595, -0.1664,  0.2566,
          1.6398,  0.4571],
        [-0.0831,  0.8620,  0.0823, -0.1653,  0.3347, -0.0838,  0.2873,  0.7111,
         -0.1560,  0.0233],
        [ 0.1213, -0.0403,  0.1530, -0.1654, -0.1285, -0.1293,  0.4315,  1.1307,
         -0.0093, -0.1573],
        [-0.1656, -0.1652,  0.0147,  0.2009, -0.0602, -0.1633, -0.1286,  0.0303,
          0.6223,  0.3229],
        [ 0.0151,  0.1042,  0.7083, -0.1081,  0.4864, -0.1419,  0.3370, -0.1575,
         -0.1525,  0.1868],
        [-0.1698, -0.1402,  0.0699,  0.3278, -0.1157, -0.1595, -0.0402, -0.0595,
          0.7225, -0.0090],
        [-0.1542, -0.0031,  0.0649, -0.0948, -0.0856, -0.1539, -0.0726,  0.6603,
          0.2684,  0.0561],
        [-0.1235, -0.1459,  0.2469,  0.0871, -0.1694, -0.1499, -0.0322,  0.4044,
          0.7017, -0.1574],
        [ 0.2839,  0.0905,  0.4717, -0.1684,  0.0582, -0.1258,  0.6478,  0.2361,
         -0.1515, -0.1700],
        [-0.1699, -

In [168]:
torch.allclose(gel_out, re_out)

False

In [170]:
def gelu(x):
    """Exact GELU: ``x`` times the standard normal CDF evaluated at ``x``."""
    erf_term = torch.erf(x / math.sqrt(2.0))
    # Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
    normal_cdf = 0.5 * (1.0 + erf_term)
    return x * normal_cdf

gelu(li_out)

tensor([[-0.0869, -0.1117, -0.1380,  0.6631, -0.1626, -0.1595, -0.1664,  0.2566,
          1.6398,  0.4571],
        [-0.0831,  0.8620,  0.0823, -0.1653,  0.3347, -0.0838,  0.2873,  0.7111,
         -0.1560,  0.0233],
        [ 0.1213, -0.0403,  0.1530, -0.1654, -0.1285, -0.1293,  0.4315,  1.1307,
         -0.0093, -0.1573],
        [-0.1656, -0.1652,  0.0147,  0.2009, -0.0602, -0.1633, -0.1286,  0.0303,
          0.6223,  0.3229],
        [ 0.0151,  0.1042,  0.7083, -0.1081,  0.4864, -0.1419,  0.3370, -0.1575,
         -0.1525,  0.1868],
        [-0.1698, -0.1402,  0.0699,  0.3278, -0.1157, -0.1595, -0.0402, -0.0595,
          0.7225, -0.0090],
        [-0.1542, -0.0031,  0.0649, -0.0948, -0.0856, -0.1539, -0.0726,  0.6603,
          0.2684,  0.0561],
        [-0.1235, -0.1459,  0.2469,  0.0871, -0.1694, -0.1499, -0.0322,  0.4044,
          0.7017, -0.1574],
        [ 0.2839,  0.0905,  0.4717, -0.1684,  0.0582, -0.1258,  0.6478,  0.2361,
         -0.1515, -0.1700],
        [-0.1699, -

In [171]:
torch.allclose(gelu(li_out), re_out)

True