# **Self-Attention Without Trainable Weights**

## **`1 context-vector` generation through Self-Attention**

In [5]:
import torch

# Toy embeddings: one 3-dimensional row per token of "Your journey starts with one step".
x = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

In [6]:
# Attention scores for the query token x^2 ("journey"):
# the dot product of x^2 with every row of x.

x_2 = x[1]
attn_score_2 = (x * x_2).sum(dim=-1)  # elementwise product, reduced over the embedding axis
attn_score_2 # Attention Scores

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [7]:
# The same scores, spelled out as an explicit loop over the tokens:
attn_score_2 = torch.empty(x.shape[0])  # one slot per token (x.shape[0] rows)
for idx in range(x.shape[0]):
    attn_score_2[idx] = (x_2 * x[idx]).sum()  # dot product of x^2 with token idx
print(attn_score_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [8]:
# Normalise the scores into attention weights that sum to 1.
attn_weight_2 = torch.softmax(attn_score_2, dim=-1)
print(attn_weight_2)
print(attn_weight_2.sum()) # 1.0

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor(1.)


In [9]:
# Echo the attention weights for x^2 computed in the previous cell.
attn_weight_2

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

In [None]:
# Context vector for x^2: the attention-weighted sum of all token embeddings.
# x has shape (6, 3), so the accumulator has 3 entries and must start at zero.
context_vector_2 = torch.zeros(x.shape[-1])

for weight, token in zip(attn_weight_2, x):
    context_vector_2 += (token * weight)

context_vector_2 # Context vector for x_2

tensor([0.4419, 0.6515, 0.5683])

## **Generating `All` Context Vectors**

In [18]:
import torch

# Same six toy token embeddings as in the first section, re-declared for this section.
x = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

In [19]:
# All pairwise dot products at once: entry (i, j) is the score of token j for query token i.
attn_scores = torch.matmul(x, x.T)
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [22]:
# Row-wise softmax: every query's weights over all tokens sum to 1.
attn_weights = torch.softmax(attn_scores, dim=-1)
attn_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [None]:
# Operands of the matmul in the next cell, annotated with their shapes.
# Only the last expression of a cell is echoed, so just `x` would be displayed here.
attn_weights # (6, 6)
x            # (6, 3)

In [None]:
# (6, 6) @ (6, 3) -> (6, 3): one context vector per token.
context_vectors = torch.matmul(attn_weights, x)
context_vectors # For all tokens: each embedding enriched with information from the surrounding vectors

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

# **Self-Attention With Trainable Weights**

In [None]:
# Attn(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V

In [25]:
import torch
import torch.nn as nn

# Same six toy token embeddings as before, re-declared for the trainable-weights section.
x = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

In [83]:
d_in = x.shape[-1] # 3
d_out = 3

# One bias-free linear projection per role (query / key / value),
# each with its own independently initialised weights.
W_q = nn.Linear(d_in, d_out, bias=False)
W_k = nn.Linear(d_in, d_out, bias=False)
W_v = nn.Linear(d_in, d_out, bias=False)

Q = W_q(x)
K = W_k(x)  # keys come from the key projection (previously W_q was reused by mistake)
V = W_v(x)  # values come from the value projection (previously W_q was reused by mistake)

d_k = K.shape[-1]

# Attn(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
context_vectors = torch.softmax((Q @ K.T) / d_k**0.5, dim=-1) @ V
context_vectors

tensor([[-0.0100,  0.4556,  0.4361],
        [-0.0115,  0.4631,  0.4445],
        [-0.0121,  0.4623,  0.4435],
        [-0.0145,  0.4554,  0.4348],
        [-0.0267,  0.4404,  0.4145],
        [-0.0084,  0.4640,  0.4462]], grad_fn=<MmBackward0>)

In [89]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with trainable projections."""

    def __init__(self, d_in, d_out, qkv_bias = False):
        '''Create the query/key/value projections (d_in -> d_out); inputs may be batched.'''
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        '''x: (batch_size, n_tokens, d_in) -> context vectors with d_out as the last dimension.'''
        queries, keys, values = self.W_q(x), self.W_k(x), self.W_v(x)
        # Scale by sqrt(d_k), the size of the last (projected) dimension of the keys.
        scale = keys.shape[-1] ** 0.5
        scores = queries @ keys.transpose(-2, -1)
        weights = (scores / scale).softmax(dim=-1)
        return weights @ values

# Smoke test: a batch of 8 sequences, 6 tokens each, 3-dim embeddings, projected to d_out = 2.
inputs = torch.rand(8, 6, 3)
selfAttention = SelfAttention(d_in=inputs.size(-1), d_out=2)
selfAttention(inputs).shape

torch.Size([8, 6, 2])

## **Masking**

In [None]:
# Toy inputs again, plus fresh projections, to get an *unmasked* (6 x 6) attention matrix.
x = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

# Square projections (d_in == d_out == 3); nn.Linear keeps its default bias here.
W_q = nn.Linear(x.shape[-1], x.shape[-1])
W_k = nn.Linear(x.shape[-1], x.shape[-1])
W_v = nn.Linear(x.shape[-1], x.shape[-1])

Q, K, V = W_q(x), W_k(x), W_v(x)

attention_score = Q @ K.transpose(-2, -1)
attention_weights = (attention_score / K.shape[-1]**0.5).softmax(dim=-1)
attention_weights # (6 x 6)

torch.Size([6, 6])


tensor([[0.1699, 0.1694, 0.1689, 0.1647, 0.1568, 0.1702],
        [0.1709, 0.1689, 0.1677, 0.1675, 0.1444, 0.1805],
        [0.1709, 0.1687, 0.1675, 0.1677, 0.1446, 0.1806],
        [0.1650, 0.1758, 0.1743, 0.1638, 0.1423, 0.1788],
        [0.1661, 0.1681, 0.1670, 0.1698, 0.1486, 0.1804],
        [0.1662, 0.1775, 0.1759, 0.1620, 0.1403, 0.1781]],
       grad_fn=<SoftmaxBackward0>)

In [None]:
# Redoing the scores, this time with a causal mask:
attention_score = torch.matmul(Q, K.transpose(-1, -2))  / K.shape[-1]**0.5
# Build the mask from positions, not from score values: `attention_score.triu(1).bool()`
# would leave a future position unmasked whenever its score happens to be exactly 0.
mask = torch.ones_like(attention_score).triu(1).bool()  # True strictly above the diagonal
attention_score = attention_score.masked_fill(mask, -torch.inf)  # -inf becomes weight 0 after softmax
attention_score 

tensor([[0.2256,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4571, 0.4453,   -inf,   -inf,   -inf,   -inf],
        [0.4504, 0.4378, 0.4303,   -inf,   -inf,   -inf],
        [0.4881, 0.5516, 0.5433, 0.4812,   -inf,   -inf],
        [0.3186, 0.3305, 0.3242, 0.3406, 0.2071,   -inf],
        [0.5518, 0.6177, 0.6087, 0.5264, 0.3828, 0.6209]],
       grad_fn=<MaskedFillBackward0>)

In [None]:
# Softmax over the masked scores: the -inf entries turn into weights of exactly 0.
attention_weights = torch.softmax(attention_score, dim=-1)
attention_weights # 😋 Masked attention weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5029, 0.4971, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3370, 0.3328, 0.3303, 0.0000, 0.0000, 0.0000],
        [0.2430, 0.2589, 0.2568, 0.2413, 0.0000, 0.0000],
        [0.2027, 0.2051, 0.2038, 0.2072, 0.1813, 0.0000],
        [0.1662, 0.1775, 0.1759, 0.1620, 0.1403, 0.1781]],
       grad_fn=<SoftmaxBackward0>)

## **Dropout**

In [126]:
# Echo the masked attention weights from the Masking section; dropout is applied to them next.
attention_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5029, 0.4971, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3370, 0.3328, 0.3303, 0.0000, 0.0000, 0.0000],
        [0.2430, 0.2589, 0.2568, 0.2413, 0.0000, 0.0000],
        [0.2027, 0.2051, 0.2038, 0.2072, 0.1813, 0.0000],
        [0.1662, 0.1775, 0.1759, 0.1620, 0.1403, 0.1781]],
       grad_fn=<SoftmaxBackward0>)

In [None]:
# Drop attention weights with probability p = 0.3.
# In training mode PyTorch rescales the surviving entries by 1 / (1 - p) = 1 / 0.7 ≈ 1.4286.
dropout = nn.Dropout(p=0.3)
dropout(attention_weights)

1.4285714285714286

## **Causal Attention (Integrating stuff till now)**

In [263]:
class CausalAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the projected query/key/value vectors (and of each output vector).
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections carry a bias term.
    """

    def __init__(self, d_in, d_out, dropout=0.5, qkv_bias=False):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        '''x: (batch, num_tokens, d_in) or (num_tokens, d_in). Returns the same leading shape with d_out last.'''
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        # Scale by sqrt(d_k) so the softmax is not pushed into saturation as d_out grows.
        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / K.shape[-1] ** 0.5
        # Positional mask (True strictly above the diagonal). Deriving it from the score
        # values would leave a future token visible whenever its score is exactly 0.
        num_tokens = x.shape[-2]
        mask = torch.ones(num_tokens, num_tokens, device=x.device).triu(1).bool()
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores, -1)
        attn_weights = self.dropout(attn_weights)
        context_vectors = torch.matmul(attn_weights, V)
        return context_vectors

# Square projection: the output width equals the input embedding width.
d_in = x.shape[-1]
d_out = d_in
causalAttention = CausalAttention(d_in, d_out)
causalAttention(x)

tensor([[ 0.7859,  0.2289, -0.9346],
        [ 0.3012,  0.4350, -0.2510],
        [ 0.6503,  0.6423, -0.6436],
        [ 0.0915,  0.1625, -0.0437],
        [ 0.1474,  0.2837, -0.2065],
        [ 0.1586,  0.2490, -0.1109]], grad_fn=<MmBackward0>)

In [273]:
# Echo the 2-D input tensor; its output shows six rows of three values.
x

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [276]:
# unsqueeze(0) adds a leading axis of size 1, i.e. a batch containing one sequence.
x.unsqueeze(0).shape

torch.Size([1, 6, 3])

In [None]:
# Improved version: the causal mask is built once and stored as a buffer.
class CausalAttention(nn.Module):
    """Single-head causal scaled dot-product self-attention with a precomputed mask.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the projected query/key/value vectors (and of each output vector).
        contextLength: Maximum number of tokens per sequence; fixes the size of the mask buffer.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections carry a bias term.
    """

    def __init__(self, d_in, d_out, contextLength, dropout=0.5, qkv_bias=False):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Registered as a buffer so it follows the module across .to(device) calls
        # and is not treated as a trainable parameter.
        self.register_buffer('mask', torch.ones(contextLength, contextLength).triu(1).bool())
    
    def forward(self, x):
        '''x: 3D tensor (batch, num_tokens, embed_dim). Returns (batch, num_tokens, d_out).

        Raises ValueError if num_tokens exceeds the contextLength given at construction.
        '''
        num_tokens = x.shape[1]
        if num_tokens > self.mask.shape[0]:
            raise ValueError(
                f"sequence of {num_tokens} tokens exceeds the context length {self.mask.shape[0]}"
            )
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        # Scale by sqrt(d_k), matching softmax(Q @ K^T / sqrt(d_k)) @ V.
        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / K.shape[-1] ** 0.5
        # Use the module's own buffer (self.mask), cropped to the actual sequence length;
        # a bare `mask` would silently depend on whatever global happens to exist.
        attn_scores = attn_scores.masked_fill(self.mask[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores, -1)
        attn_weights = self.dropout(attn_weights)
        context_vectors = torch.matmul(attn_weights, V)
        return context_vectors

# forward() expects (batch, num_tokens, embed_dim); x is 2-D, so unsqueeze(0)
# adds a leading batch axis of size 1. The context length (30) only needs to be >= num_tokens.
d_in = d_out = x.size(-1)
causalAttention = CausalAttention(d_in, d_out, 30)
causalAttention(x.unsqueeze(0))

tensor([[[-0.3677,  1.2444,  0.2321],
         [-0.1909,  0.6462,  0.1205],
         [-0.0668,  0.5810, -0.0445],
         [-0.0911,  0.3084,  0.0575],
         [-0.1181,  0.5126,  0.0071],
         [-0.0893,  0.5791, -0.0138]]], grad_fn=<UnsafeViewBackward0>)

# **😋😋 Multihead - Attention**

## **General Overview**

In [346]:
# Six toy token embeddings (3 dims each) ...
x = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

# ... stacked twice along a new leading axis to simulate a batch of two sequences.
batch = torch.stack((x, x), dim=0)
batch.shape

torch.Size([2, 6, 3])

In [None]:
b, num_tokens, token_embed = batch.shape # (2, 6, 3)

d_in = 3
d_out = 8
num_heads = 4
d_head = d_out // num_heads  # each attention head works on a 2-dim slice of the output

W_q = nn.Linear(d_in, d_out, bias=False)
W_k = nn.Linear(d_in, d_out, bias=False)
W_v = nn.Linear(d_in, d_out, bias=False)

# Per tensor: project                  (2, 6, 8)    = (batch_size, num_tokens, d_out)
#             split d_out into heads   (2, 6, 4, 2) = (batch_size, num_tokens, num_heads, d_head)
#             swap token / head axes   (2, 4, 6, 2) = (batch_size, num_heads, num_tokens, d_head)
Q = W_q(batch).view(b, num_tokens, num_heads, d_head).transpose(1, 2)
K = W_k(batch).view(b, num_tokens, num_heads, d_head).transpose(1, 2)
V = W_v(batch).view(b, num_tokens, num_heads, d_head).transpose(1, 2)

torch.Size([2, 4, 6, 2])

In [None]:
# Scaled dot-product scores for every head; dividing by sqrt(d_head) matches
# softmax(Q @ K^T / sqrt(d_k)) @ V and the MultiheadAttention class defined later.
attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / d_head**0.5 # (2, 4, 6, 6)

# Right now `num_tokens = 6`; let's set the context length to 10.
context_length = 10 # i.e. the maximum number of tokens any sample in a batch may have
# True strictly above the diagonal = "future" positions that must be hidden.
mask = torch.ones(context_length, context_length).triu(1).bool()
mask

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])

In [366]:
# Crop the (10 x 10) mask to the actual sequence length; the resulting (6 x 6) mask
# broadcasts over the batch and head axes of attn_scores.
masked_attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], float("-inf"))
masked_attn_scores

tensor([[[[-0.0132,    -inf,    -inf,    -inf,    -inf,    -inf],
          [ 0.0570,  0.0728,    -inf,    -inf,    -inf,    -inf],
          [ 0.0539,  0.0612,  0.0605,    -inf,    -inf,    -inf],
          [ 0.0470,  0.0678,  0.0669,  0.0378,    -inf,    -inf],
          [-0.0184, -0.1650, -0.1611, -0.1117, -0.0464,    -inf],
          [ 0.0790,  0.1816,  0.1783,  0.1110,  0.0688,  0.1488]],

         [[ 0.0330,    -inf,    -inf,    -inf,    -inf,    -inf],
          [ 0.0895,  0.1253,    -inf,    -inf,    -inf,    -inf],
          [ 0.0928,  0.1316,  0.1219,    -inf,    -inf,    -inf],
          [ 0.0416,  0.0516,  0.0463,  0.0409,    -inf,    -inf],
          [ 0.1259,  0.2086,  0.1993,  0.1374, -0.0235,    -inf],
          [ 0.0132, -0.0045, -0.0089,  0.0076, -0.0852,  0.0494]],

         [[-0.0204,    -inf,    -inf,    -inf,    -inf,    -inf],
          [-0.0255, -0.0348,    -inf,    -inf,    -inf,    -inf],
          [-0.0246, -0.0330, -0.0304,    -inf,    -inf,    -inf],
      

In [367]:
# Softmax over the key axis; masked (-inf) positions receive a weight of 0.
masked_attn_weights = torch.softmax(masked_attn_scores, dim=-1)
masked_attn_weights

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4961, 0.5039, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3318, 0.3342, 0.3340, 0.0000, 0.0000, 0.0000],
          [0.2480, 0.2532, 0.2530, 0.2458, 0.0000, 0.0000],
          [0.2167, 0.1872, 0.1879, 0.1974, 0.2107, 0.0000],
          [0.1585, 0.1757, 0.1751, 0.1637, 0.1569, 0.1700]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4911, 0.5089, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3258, 0.3387, 0.3355, 0.0000, 0.0000, 0.0000],
          [0.2491, 0.2516, 0.2503, 0.2490, 0.0000, 0.0000],
          [0.1986, 0.2157, 0.2137, 0.2009, 0.1710, 0.0000],
          [0.1696, 0.1666, 0.1658, 0.1686, 0.1537, 0.1758]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5023, 0.4977, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3349, 0.3321, 0.3330, 0.0000, 0.0000, 0.0000],
          [0.2500, 0.2496, 0.2500, 0.2505, 0.0000, 0.0000],
          [0.1985, 0.2004, 0.2005, 0

In [392]:
# Weighted sum of the values, computed independently for every head.
context_vectors = torch.matmul(masked_attn_weights, V)
context_vectors.shape # (2, 4, 6, 2) = (batch_size, num_heads, num_tokens, d_head)

torch.Size([2, 4, 6, 2])

In [None]:
# Merge the heads back: move the head axis next to d_head, then flatten (num_heads, d_head) into d_out.
# Caution: this cell overwrites `context_vectors` in place, so it is meant to run once
# right after the cell that computes `masked_attn_weights @ V`.
context_vectors = context_vectors.transpose(1, 2) # (2, 6, 4, 2) = (batch_size, num_tokens, num_heads, d_head)
# NOTE: The transpose() operation above makes the tensor `context_vectors` non-contiguous,
# so .contiguous() is called before .view().
context_vectors = context_vectors.contiguous().view((b, num_tokens, d_out))
# context_vectors = context_vectors.view((b, num_tokens, d_out)) # ❌ (no .contiguous() first)
context_vectors.shape

torch.Size([2, 6, 8])

In [396]:
# Everything in one place: weighted sum per head, then merge the heads back together.
context_vectors = (
    (masked_attn_weights @ V)    # (2, 4, 6, 2) = (batch_size, num_heads, num_tokens, d_head)
    .transpose(1, 2)             # (2, 6, 4, 2) = (batch_size, num_tokens, num_heads, d_head)
    .contiguous()                # transpose() left the tensor non-contiguous
    .view((b, num_tokens, d_out))
)
context_vectors.shape

torch.Size([2, 6, 8])

## **Creating the Multihead-Attention Class**

In [36]:
import torch 
import torch.nn as nn

class MultiheadAttention(nn.Module):
    """Causal multi-head self-attention followed by an output projection.

    The d_out-wide query/key/value projections are split into `n_heads` heads of
    width d_out // n_heads, attended independently, then concatenated and passed
    through `out_proj`.

    Args:
        d_in: Size of each input token embedding.
        d_out: Total output width across all heads; must be divisible by n_heads.
        n_heads: Number of attention heads.
        context_length: Maximum number of tokens per sequence (size of the mask buffer).
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections carry a bias term.
    """

    def __init__(self, d_in, d_out, n_heads, context_length, dropout=0.5, qkv_bias=False):
        super().__init__()
        assert (d_out % n_heads == 0)

        self.d_in, self.d_out = d_in, d_out
        self.n_heads = n_heads
        self.d_head = d_out // n_heads

        # Layers are created in the order W_q, W_k, W_v, out_proj.
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_out, d_out)

        # True strictly above the diagonal marks future positions to hide.
        causal = torch.ones(context_length, context_length).triu(1).bool()
        self.register_buffer('mask', causal)

    def forward(self, x):
        '''x: (batch_size, num_tokens, d_in) -> (batch_size, num_tokens, d_out).'''
        b, n_tokens, token_embed = x.shape
        assert self.d_in == token_embed

        def split_heads(t):
            # (b, n_tokens, d_out) -> (b, n_heads, n_tokens, d_head)
            return t.view(b, n_tokens, self.n_heads, self.d_head).transpose(1, 2)

        Q = split_heads(self.W_q(x))
        K = split_heads(self.W_k(x))
        V = split_heads(self.W_v(x))

        scores = (Q @ K.transpose(-2, -1)) / self.d_head**0.5
        scores = scores.masked_fill(self.mask[:n_tokens, :n_tokens], -torch.inf)
        weights = self.dropout(torch.softmax(scores, dim=-1))

        # (b, n_heads, n_tokens, d_head) -> (b, n_tokens, d_out); transpose needs .contiguous() before .view()
        merged = (weights @ V).transpose(1, 2).contiguous().view(b, n_tokens, self.d_out)
        return self.out_proj(merged)

# Smoke test: 2 sequences, 5 tokens each, 4-dim embeddings.
x = torch.rand(2, 5, 4)
d_in = x.size(-1)
d_out = 8
n_heads = 4
# d_head = d_out // n_heads = 8 // 4 = 2
mha = MultiheadAttention(d_in, d_out, n_heads, context_length=20)
mha(x)

tensor([[[-0.1669,  0.0420, -0.0686,  0.1521, -0.0018,  0.2404, -0.2778,
          -0.3507],
         [-0.4062,  0.7697, -0.3177,  0.6036, -0.0229,  0.6004,  0.4584,
          -0.3751],
         [-0.2924,  0.5159, -0.2394,  0.3982,  0.0052,  0.5141,  0.2366,
          -0.3548],
         [-0.0708,  0.1552, -0.2544,  0.1053, -0.0647,  0.3601, -0.1160,
          -0.6176],
         [-0.1087,  0.2616, -0.2624,  0.1804,  0.0019,  0.4438,  0.0662,
          -0.5143]],

        [[-0.5400,  0.6497, -0.0971,  0.6654,  0.5507,  0.2037,  0.2026,
          -0.1828],
         [ 0.0379,  0.1156, -0.2986,  0.0354,  0.0167,  0.3991, -0.0222,
          -0.6254],
         [-0.4015,  0.5563, -0.1378,  0.4881,  0.1133,  0.4257,  0.1552,
          -0.2579],
         [-0.1357,  0.6337, -0.3312,  0.3105, -0.0815,  0.5367,  0.3753,
          -0.4939],
         [-0.0886,  0.4219, -0.2576,  0.2222,  0.0583,  0.4556,  0.2306,
          -0.3500]]], grad_fn=<ViewBackward0>)