From a9d72211641c8547c2850bb5081af607a32cc291 Mon Sep 17 00:00:00 2001 From: Alfredo Canziani Date: Thu, 19 Oct 2023 15:56:56 -0400 Subject: [PATCH] Fix ugly patch attempt --- 15-transformer.ipynb | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/15-transformer.ipynb b/15-transformer.ipynb index eb8b367ad..0393fb3bc 100644 --- a/15-transformer.ipynb +++ b/15-transformer.ipynb @@ -27,7 +27,7 @@ "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "nn_Softargmax = nn.Softmax # fix wrong name" + "f.softargmax = f.softmax # fix wrong name" ] }, { @@ -70,14 +70,14 @@ " batch_size = Q.size(0) \n", " k_length = K.size(-2) \n", " \n", - " # Scaling by d_k so that the soft(arg)max doesnt saturate\n", - " Q = Q / np.sqrt(self.d_k) # (bs, n_heads, q_length, dim_per_head)\n", - " scores = torch.matmul(Q, K.transpose(2,3)) # (bs, n_heads, q_length, k_length)\n", + " # Scaling by d_k so that the softargmax doesnt saturate\n", + " Q = Q / np.sqrt(self.d_k) # (bs, n_heads, q_length, dim_per_head)\n", + " scores = torch.matmul(Q, K.transpose(2,3)) # (bs, n_heads, q_length, k_length)\n", " \n", - " A = nn_Softargmax(dim=-1)(scores) # (bs, n_heads, q_length, k_length)\n", + " A = f.softargmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)\n", " \n", " # Get the weighted average of the values\n", - " H = torch.matmul(A, V) # (bs, n_heads, q_length, dim_per_head)\n", + " H = torch.matmul(A, V) # (bs, n_heads, q_length, dim_per_head)\n", "\n", " return H, A \n", "\n", @@ -674,9 +674,7 @@ { "cell_type": "code", "execution_count": 33, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -749,7 +747,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 [conda env:pDL]", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -763,9 +761,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.2" + "version": "3.10.12" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 }