In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy



class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected with three linear layers, split into
    ``num_heads`` heads of width ``d_k = d_model // num_heads``, attended
    per head, merged back to width ``d_model`` and passed through a final
    output projection.

    Args:
        d_model: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Each head gets an equal slice of the model dimension.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model  # Model's dimension
        self.num_heads = num_heads  # Number of attention heads
        self.d_k = d_model // num_heads  # Dimension of each head's key, query, and value

        # Linear layers for transforming inputs
        self.W_q = nn.Linear(d_model, d_model)  # Query transformation
        self.W_k = nn.Linear(d_model, d_model)  # Key transformation
        self.W_v = nn.Linear(d_model, d_model)  # Value transformation
        self.W_o = nn.Linear(d_model, d_model)  # Output transformation

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Attend over ``V`` using softmax(Q K^T / sqrt(d_k)).

        Args:
            Q, K, V: Per-head tensors whose last two axes are
                (sequence, d_k), as produced by ``split_heads``.
            mask: Optional tensor broadcastable to the score tensor.
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            The attention-weighted combination of ``V``.
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Fill with the most negative finite value of the score dtype rather
            # than a fixed -1e9: the literal cannot be represented in float16
            # and makes masked_fill raise there. After softmax the masked
            # positions still receive zero probability.
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)

        # Normalise over the key axis to obtain attention probabilities.
        attn_probs = torch.softmax(attn_scores, dim=-1)

        return torch.matmul(attn_probs, V)

    def split_heads(self, x):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of ``split_heads``: (batch, num_heads, seq, d_k) -> (batch, seq, d_model)."""
        batch_size, _, seq_length, _ = x.size()
        # transpose() returns a non-contiguous view, so make it contiguous before view().
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Run multi-head attention.

        Args:
            Q, K, V: Rank-3 tensors whose last axis has size ``d_model``.
            mask: Optional mask forwarded to ``scaled_dot_product_attention``.

        Returns:
            Tensor with the same leading shape as ``Q`` and last axis ``d_model``.
        """
        # Apply linear transformations and split heads
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Perform scaled dot-product attention
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)

        # Combine heads and apply output transformation
        return self.W_o(self.combine_heads(attn_output))
    


class PositionWiseFeedForward(nn.Module):
    """Two linear layers with a ReLU in between, applied along the last axis.

    Args:
        d_model: Size of the input and output features.
        d_ff: Size of the hidden layer between the two projections.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Widen to d_ff, then project back down to d_model.
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)
    
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to its input.

    A table of shape (1, max_seq_length, d_model) is precomputed and stored
    as the non-trainable buffer ``pe``: even feature columns hold
    ``sin(position * div_term)`` and odd columns hold
    ``cos(position * div_term)``.

    Args:
        d_model: Size of the feature axis. Both even and odd values are supported.
        max_seq_length: Number of positions to precompute.
    """

    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # One frequency per even column index: 10000^(-i / d_model) for i = 0, 2, 4, ...
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd column than even columns,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Registered as a buffer so it follows .to()/.cuda() and is saved in the
        # state dict without being treated as a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encodings for the first ``x.size(1)`` positions to ``x``."""
        return x + self.pe[:, :x.size(1)]


class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then a feed-forward network.

    Each sub-layer's output goes through dropout, is added back to its
    input, and the sum is layer-normalised.

    Args:
        d_model: Feature size of the layer's input and output.
        num_heads: Number of heads for the self-attention sub-layer.
        d_ff: Hidden size of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # A single dropout module is shared by both residual branches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Apply the block to ``x``; ``mask`` is passed to the self-attention."""
        # Sub-layer 1: self-attention with residual connection and normalisation.
        attended = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward with residual connection and normalisation.
        transformed = self.feed_forward(x)
        return self.norm2(x + self.dropout(transformed))
    

class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is followed by dropout, a residual
    addition and its own layer normalisation.

    Args:
        d_model: Feature size of the layer's input and output.
        num_heads: Number of heads for both attention sub-layers.
        d_ff: Hidden size of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        # A single dropout module is shared by all three residual branches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        """Apply the block.

        Args:
            x: Decoder-side input.
            enc_output: Encoder output, used as keys and values for cross-attention.
            src_mask: Mask passed to the cross-attention.
            tgt_mask: Mask passed to the self-attention.
        """
        # Sub-layer 1: self-attention over the decoder input.
        self_attended = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attended))

        # Sub-layer 2: queries come from the decoder, keys/values from the encoder.
        cross_attended = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attended))

        # Sub-layer 3: position-wise feed-forward.
        transformed = self.feed_forward(x)
        return self.norm3(x + self.dropout(transformed))



In [2]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from the layers defined above.

    Args:
        src_vocab_size: Number of rows in the source embedding table.
        tgt_vocab_size: Number of rows in the target embedding table and
            size of the output logits.
        d_model: Embedding / hidden feature size.
        num_heads: Attention heads per layer.
        num_layers: Number of encoder layers and, equally, decoder layers.
        d_ff: Hidden size of each feed-forward sub-layer.
        max_seq_length: Number of positions covered by the positional encoding.
        dropout: Dropout probability used throughout the model.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        """Build boolean attention masks from token-id tensors.

        Token id 0 is treated as padding.

        Args:
            src: Source token ids of shape (batch, src_len).
            tgt: Target token ids of shape (batch, tgt_len).

        Returns:
            ``(src_mask, tgt_mask)`` where ``src_mask`` has shape
            (batch, 1, 1, src_len) and is False at padded source positions, and
            ``tgt_mask`` has shape (batch, 1, tgt_len, tgt_len) and is True only
            where the query row is not padding and the key column is not after
            the query position.
        """
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        # Create the causal mask on the same device as the inputs; building it on
        # the default device makes the `&` below fail once the model runs on GPU.
        nopeak_mask = (
            1 - torch.triu(torch.ones(1, seq_length, seq_length, device=tgt.device), diagonal=1)
        ).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        """Return logits of shape (batch, tgt_len, tgt_vocab_size).

        Args:
            src: Source token ids of shape (batch, src_len).
            tgt: Target token ids of shape (batch, tgt_len).
        """
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        # Encoder stack.
        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        # Decoder stack; every layer attends to the final encoder output.
        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        # Project decoder features to target-vocabulary logits.
        output = self.fc(dec_output)
        return output

In [3]:
# Fix the RNG so weight initialisation and the random batches below are
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Small configuration for a quick demo run.
# Larger alternative kept for reference: vocab sizes 5000, d_model 512, num_heads 8,
# num_layers 6, d_ff 2048, max_seq_length 100, dropout 0.1.
src_vocab_size = 200
tgt_vocab_size = 200
d_model = 512
num_heads = 8
num_encoder_layers = 6
num_decoder_layers = 6
# Transformer takes a single layer count that it uses for both stacks.
num_layers = num_encoder_layers
d_ff = 1024
max_seq_length = 50
dropout = 0

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data. Ids start at 1, so no token equals 0,
# the value generate_mask treats as padding.
src_data = torch.randint(1, src_vocab_size, (32, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (32, max_seq_length))  # (batch_size, seq_length)

In [4]:
# Targets equal to 0 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)
transformer.train()

# The decoder is fed the target without its last token and is scored against
# the target without its first token (next-token prediction).
decoder_input = tgt_data[:, :-1]
flat_targets = tgt_data[:, 1:].contiguous().view(-1)

# Each iteration is one optimisation step on the same fixed batch.
for epoch in range(5):
    optimizer.zero_grad()
    output = transformer(src_data, decoder_input)
    print(src_data.shape, decoder_input.shape)
    print(output.shape)
    print("####")

    # Flatten (batch, seq, vocab) logits to (batch * seq, vocab) for the loss.
    flat_logits = output.contiguous().view(-1, tgt_vocab_size)
    loss = criterion(flat_logits, flat_targets)

    print(flat_logits.shape, flat_targets.shape)

    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")
    print()

torch.Size([32, 50]) torch.Size([32, 49])
torch.Size([32, 49, 200])
####
torch.Size([1568, 200]) torch.Size([1568])
Epoch: 1, Loss: 5.452166557312012

torch.Size([32, 50]) torch.Size([32, 49])
torch.Size([32, 49, 200])
####
torch.Size([1568, 200]) torch.Size([1568])
Epoch: 2, Loss: 5.264164924621582

torch.Size([32, 50]) torch.Size([32, 49])
torch.Size([32, 49, 200])
####
torch.Size([1568, 200]) torch.Size([1568])
Epoch: 3, Loss: 5.169758319854736

torch.Size([32, 50]) torch.Size([32, 49])
torch.Size([32, 49, 200])
####
torch.Size([1568, 200]) torch.Size([1568])
Epoch: 4, Loss: 5.104924201965332

torch.Size([32, 50]) torch.Size([32, 49])
torch.Size([32, 49, 200])
####
torch.Size([1568, 200]) torch.Size([1568])
Epoch: 5, Loss: 5.028228759765625



In [5]:
# NOTE(review): IPython `?` help lookup left over from interactive exploration. It only
# displays the signature and docstring shown below; consider removing it from the final notebook.
nn.CrossEntropyLoss?

[0;31mInit signature:[0m
[0mnn[0m[0;34m.[0m[0mCrossEntropyLoss[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mweight[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0msize_average[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mignore_index[0m[0;34m:[0m [0mint[0m [0;34m=[0m [0;34m-[0m[0;36m100[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreduce[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreduction[0m[0;34m:[0m [0mstr[0m [0;34m=[0m [0;34m'mean'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlabel_smoothing[0m[0;34m:[0m [0mfloat[0m [0;34m=[0m [0;36m0.0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
This criterion computes the cross entropy loss between input logits
and target.

It is useful when trainin

In [6]:
# Put the model in evaluation mode before scoring.
transformer.eval()

# Generate random sample validation data: token ids in [1, vocab_size).
val_batch_shape = (64, max_seq_length)  # (batch_size, seq_length)
val_src_data = torch.randint(1, src_vocab_size, val_batch_shape)
val_tgt_data = torch.randint(1, tgt_vocab_size, val_batch_shape)

# No gradients are needed for evaluation.
with torch.no_grad():
    val_output = transformer(val_src_data, val_tgt_data[:, :-1])
    # Same input/target shift and flattening as in the training loop.
    val_logits = val_output.contiguous().view(-1, tgt_vocab_size)
    val_targets = val_tgt_data[:, 1:].contiguous().view(-1)
    val_loss = criterion(val_logits, val_targets)
    print(f"Validation Loss: {val_loss.item()}")

Validation Loss: 5.4163031578063965


In [1]:
import torch 


# Hard-coded query tensor of shape (3, 4, 4, 4), used as a fixture for the
# step-by-step attention computation in the cells below.
# NOTE(review): axes are presumably (batch, heads, seq_len, head_dim) - confirm against
# whatever produced these values.
Q_reshaped =  torch.tensor([[[[-4.7772e-01,  6.9994e-01, -2.4317e-01,  1.3753e-01],
          [-2.0450e-01,  6.6614e-01,  3.5755e-01, -3.1147e-01],
          [-9.6026e-01,  2.4673e+00, -2.9600e-01,  2.9330e+00],
          [-8.9694e-01,  2.3048e+00,  9.7365e-01,  2.5556e+00]],

         [[-9.8551e-01, -7.5628e-01,  2.1984e+00, -1.0317e+00],
          [-1.5500e+00,  3.5095e-01, -6.8803e-02, -5.2188e-01],
          [-1.8930e+00,  5.8479e-01,  1.4320e+00,  1.2516e-01],
          [-3.0953e+00,  2.3745e+00,  9.2446e-01,  4.7702e-02]],

         [[-8.3401e-01,  1.4440e+00,  4.0272e-01, -4.4414e-01],
          [-5.0454e-01,  3.9886e-01, -2.7185e-02, -2.6923e-01],
          [-3.8632e-01, -1.6106e+00, -6.3674e-01,  1.9070e-01],
          [-6.7301e-01, -1.8924e+00, -1.4568e+00,  1.3190e-01]],

         [[ 5.0697e-01,  1.3615e+00, -8.3323e-01, -1.5126e-01],
          [ 2.2424e-01,  1.6721e+00, -1.2046e+00,  3.8664e-02],
          [ 2.7266e+00,  1.4596e+00,  1.5861e+00,  1.1367e+00],
          [ 3.3616e+00,  1.9529e+00,  7.4700e-01,  1.5359e+00]]],


        [[[-8.6815e-01,  1.2308e+00, -3.3001e-02, -3.7308e-01],
          [ 4.9939e-01,  1.5917e+00,  2.1246e-01,  7.2611e-02],
          [ 1.1958e-01,  1.7782e+00, -8.0772e-01, -3.1812e-01],
          [-8.7687e-01,  3.1814e-01,  2.2518e-01, -2.2660e-01]],

         [[-1.9160e+00,  3.4028e-01,  1.2422e+00,  5.2760e-01],
          [-6.4376e-01,  5.2752e-01, -7.8839e-01, -1.0246e+00],
          [-2.1223e+00,  1.9347e-01,  2.1142e-01, -1.4601e+00],
          [-1.2929e+00, -5.0758e-01,  2.0995e-01, -1.9992e-02]],

         [[-1.9333e+00,  2.1132e-01, -1.0623e-01, -2.4567e-01],
          [-1.4550e+00, -5.4272e-02, -6.6537e-01,  1.1971e+00],
          [-1.6575e+00, -3.0775e-01,  1.6449e-02,  7.6033e-01],
          [-1.0600e+00, -1.3049e+00,  6.8640e-01, -1.7993e-03]],

         [[ 1.8422e-02,  1.3085e+00, -1.0644e+00, -2.7683e-01],
          [ 2.3151e-01,  7.6883e-01, -1.3854e+00, -5.2912e-01],
          [ 8.1380e-01,  2.1532e+00, -1.1853e+00,  3.6409e-01],
          [-3.3871e-01,  1.6210e+00,  2.8922e-01,  1.8015e+00]]],


        [[[-6.7476e-01, -1.2707e+00,  1.1343e+00, -1.9244e+00],
          [-6.3113e-01,  6.2221e-01,  2.5213e-02, -1.3570e-01],
          [-1.0503e+00,  2.2270e+00,  1.2420e+00,  2.2824e+00],
          [-3.5791e-01,  5.8840e-01,  6.2594e-01, -5.8470e-01]],

         [[ 4.6403e-01, -2.9052e-01,  3.5396e-02,  1.2491e+00],
          [-9.0462e-01, -8.1335e-01,  2.3499e+00, -5.4459e-01],
          [-3.0144e+00,  2.3174e+00,  1.0760e+00,  5.3479e-01],
          [-1.4691e+00,  2.9387e-01,  8.2767e-02, -3.4799e-02]],

         [[-8.6633e-01,  8.0823e-02,  2.3088e-01, -1.6974e+00],
          [-9.7112e-01,  1.4730e+00,  1.0912e-01, -7.1013e-01],
          [-8.1012e-01, -1.8634e+00, -1.7504e+00, -1.3410e-01],
          [-6.4165e-01,  4.2784e-01, -3.2078e-01, -5.3522e-01]],

         [[-1.9658e+00,  7.5742e-01, -1.6647e-01,  2.2219e-01],
          [ 1.6952e-01,  1.2077e+00, -1.1462e+00, -3.7947e-01],
          [ 3.0241e+00,  1.7991e+00,  4.3402e-01,  1.3077e+00],
          [-1.1321e-01,  1.5183e+00, -1.5175e+00, -1.8955e-01]]]],
       )

# Hard-coded key tensor of shape (3, 4, 4, 4); its last two axes are transposed
# and multiplied with Q_reshaped in a later cell.
# NOTE(review): axes are presumably (batch, heads, seq_len, head_dim) - confirm.
K_reshaped =  torch.tensor([[[[ 7.8398e-01, -2.8978e+00, -1.9969e+00, -8.1522e-02],
          [-2.4054e-01, -1.1789e+00, -3.2179e-01,  3.0028e-01],
          [ 2.2459e+00,  1.1807e-01, -8.4272e-01, -5.2088e-01],
          [ 3.5401e-01,  8.6656e-01,  4.8169e-01,  1.7585e-01]],

         [[-2.2458e-01, -1.4127e-01, -1.5425e-02, -4.3561e-01],
          [-4.5286e-01, -1.7143e-01,  7.3920e-01, -2.2124e-01],
          [-8.1991e-01,  9.2554e-01,  3.2352e-02,  7.1836e-01],
          [-1.3430e+00,  1.0354e+00,  7.8981e-01, -3.3099e-01]],

         [[ 4.3926e-01,  1.0542e+00,  2.5665e-03, -2.5535e+00],
          [ 8.3438e-01,  4.8703e-01,  3.7392e-01, -3.1019e-01],
          [ 5.9615e-01, -2.5148e-01,  2.7867e-01, -1.7343e+00],
          [ 1.0606e+00, -8.2105e-02, -9.2336e-01, -1.8193e-01]],

         [[ 2.1036e-01,  1.3534e+00, -4.9534e-01, -2.2155e+00],
          [-2.2480e-01,  1.1705e+00, -5.1477e-01, -4.6465e-01],
          [ 8.3970e-01,  8.2966e-01,  1.5627e-01, -1.7664e-01],
          [ 1.2857e+00,  1.8845e+00, -7.1693e-02, -4.5686e-01]]],


        [[[ 1.5785e+00, -8.0012e-01, -1.2165e+00,  1.1637e+00],
          [ 4.7456e-01, -1.2277e+00, -1.7932e+00,  6.6939e-01],
          [ 2.9108e-01, -1.0741e+00, -4.3732e-01, -3.2007e-02],
          [ 1.0022e+00,  3.9241e-02, -3.9030e-01,  8.0718e-01]],

         [[-1.1497e+00,  1.8438e-01,  4.2984e-01,  2.9018e-01],
          [-1.5817e+00, -4.8684e-01, -5.8341e-01, -5.8716e-01],
          [-1.0270e+00, -4.7455e-01,  1.2170e-02,  3.4087e-01],
          [-8.9543e-01,  4.3494e-02,  1.0376e+00,  7.5883e-01]],

         [[-5.1195e-03, -4.5586e-02,  3.8176e-01, -7.6378e-01],
          [ 1.7891e-01, -8.7227e-01,  2.0302e-01, -1.4165e+00],
          [ 6.7352e-01,  5.4276e-02,  1.2021e+00, -4.5426e-01],
          [ 4.1585e-01, -1.5690e+00,  5.2294e-01,  6.9571e-01]],

         [[ 1.0573e+00,  9.9475e-01, -2.7674e-01, -1.8682e+00],
          [ 2.2913e+00, -3.4779e-02,  1.0709e+00, -2.4700e+00],
          [ 6.2046e-01,  1.0307e+00, -2.8338e-01, -1.4275e+00],
          [ 6.6790e-01, -3.8380e-01, -1.8104e-01, -8.4886e-01]]],


        [[[-4.2460e-01, -4.5856e-01,  7.9080e-01,  8.4442e-01],
          [ 4.9698e-01, -2.2869e+00, -1.9085e+00,  1.1746e-02],
          [ 6.7006e-02,  1.4774e+00,  5.7012e-01,  2.6911e-01],
          [-5.2754e-01, -5.6802e-01, -2.3336e-01,  3.9355e-01]],

         [[-2.2514e-01, -8.0630e-01,  1.2966e+00,  8.4586e-01],
          [-5.8476e-01,  7.6748e-02, -1.4493e-01,  7.0403e-02],
          [-1.7032e+00,  1.2534e+00,  6.6030e-01,  1.7503e-01],
          [-8.1305e-01,  4.6589e-02,  6.0970e-01,  2.8478e-01]],

         [[-6.5667e-01, -7.2372e-01,  3.2590e-01,  7.6517e-01],
          [ 1.1380e-01,  6.9370e-01,  2.3814e-01, -2.5993e+00],
          [ 7.3511e-01, -4.4257e-01, -6.8778e-01, -2.2767e-01],
          [ 5.0892e-01,  1.2657e-01,  6.0949e-01, -3.5594e-01]],

         [[-7.2759e-02, -1.0743e+00, -6.6493e-02, -6.7504e-01],
          [ 7.4326e-01,  1.2273e+00, -6.6120e-01, -2.2911e+00],
          [ 1.8186e+00,  1.7584e+00, -2.3756e-01, -5.3247e-01],
          [ 3.0809e-01,  1.0444e+00, -6.8064e-01, -5.4025e-01]]]])


# Hard-coded value tensor of shape (3, 4, 4, 4); it is multiplied by the
# attention weights in the final cell.
# NOTE(review): axes are presumably (batch, heads, seq_len, head_dim) - confirm.
V_reshaped =  torch.tensor([[[[-1.0260,  0.9662, -2.2913,  1.8057],
          [ 0.5957, -0.4363, -0.6107,  0.1757],
          [ 1.1529,  0.6390,  0.4691,  0.1807],
          [ 2.0451, -0.5014,  1.6166, -0.3022]],

         [[-0.0240, -1.3353, -1.6751, -0.9049],
          [-0.4670, -1.8406, -0.9363, -0.2806],
          [ 0.2908,  0.4321,  1.0739, -0.3611],
          [-0.6475, -2.4038,  1.9287,  0.6706]],

         [[ 1.5304, -1.8514, -1.5147,  0.9213],
          [ 0.9723, -0.8812, -0.8274,  1.0724],
          [ 1.3649, -0.4078, -1.0577, -1.9731],
          [ 1.7294, -1.8886, -2.2192, -0.2645]],

         [[ 0.5280,  0.7117, -0.2574,  1.5133],
          [ 0.4975,  0.9409, -0.3728, -0.8591],
          [ 0.0418, -0.3742,  0.1779,  0.3042],
          [-0.5309, -0.6696, -0.5339, -0.4551]]],


        [[[-0.2025, -0.9508, -1.4185, -0.6180],
          [-0.5360,  0.3189,  0.1011,  0.4221],
          [ 0.9998, -0.8083, -0.9948,  0.1297],
          [ 0.0029,  0.2251, -1.0812, -1.1702]],

         [[-0.2939, -0.2418, -0.6054,  1.2000],
          [-0.2660, -1.9972, -1.2478,  1.2041],
          [-1.2195, -0.9846, -0.4558,  0.8347],
          [-0.4921, -0.1789,  1.0497,  1.6185]],

         [[ 0.0906, -1.0008, -0.3126,  0.5912],
          [ 1.2918,  0.2884, -0.8203,  0.6532],
          [ 1.3875, -1.1471, -1.8779,  0.5379],
          [ 0.0668,  0.4564,  0.1974,  0.0822]],

         [[ 0.0505, -0.4071,  0.3731,  0.2199],
          [ 2.5341,  0.5179,  0.5687, -1.4736],
          [ 0.9913,  0.6232, -1.3847, -1.2899],
          [-0.1291,  0.0954, -0.4044,  0.5938]]],


        [[[-0.8220,  0.0050, -0.6552, -0.7446],
          [-1.1864,  0.6515, -2.0828,  1.4967],
          [ 1.8847, -0.8161,  1.8250, -0.6112],
          [ 0.4353, -0.7510, -0.4023, -0.1333]],

         [[ 0.2215, -0.0855,  0.7154,  1.1723],
          [-0.1140, -1.0901, -1.3602, -0.3474],
          [-0.7375, -2.1587,  2.2436,  1.2281],
          [-0.5570, -1.5954, -0.6214,  0.2769]],

         [[-1.2167,  0.5764,  1.2362,  0.0251],
          [ 1.1200, -1.7065, -1.3507,  1.0216],
          [ 1.3190, -1.7436, -2.0552, -0.1642],
          [ 0.5620, -0.7362, -0.6634,  1.1727]],

         [[-0.0334,  0.3178,  0.5097,  0.4712],
          [ 0.5073,  0.2432,  0.0210,  1.4853],
          [-0.5516, -1.1381, -0.2555, -0.4832],
          [ 0.4768,  0.4724, -0.0945, -0.8872]]]])

In [2]:
import math

# Attention scores are scaled by 1/sqrt(d), where d is the size of the
# last axis of the query tensor.
head_dim = Q_reshaped.size(-1)
scale_factor = 1 / math.sqrt(head_dim)
scale_factor

0.5

In [3]:
# Q_reshaped.size(-1), Q_reshaped.shape


# Additive causal mask of shape (1, 1, 4, 4): 0 on and below the diagonal,
# -inf strictly above it. triu() zeroes everything below the chosen diagonal.
attn_mask = torch.triu(torch.full((1, 1, 4, 4), float('-inf')), diagonal=1)



In [4]:
# `attn_mask` is built in the previous cell; repeating the identical literal here
# was redundant, so this cell only inspects its shape.
attn_mask.shape

torch.Size([1, 1, 4, 4])

In [5]:
# L: size of the second-to-last axis of the queries; S: same axis of the keys.
L = Q_reshaped.size(-2)
S = K_reshaped.size(-2)

# Start from an all-zero (L, S) bias and prepend two singleton axes so it
# broadcasts against the 4-D score tensor.
attn_bias = torch.zeros(L, S, dtype=Q_reshaped.dtype)[None, None]
print(attn_bias.shape)

# Fold the additive mask (0 / -inf) into the bias.
attn_bias += attn_mask

torch.Size([1, 1, 4, 4])


In [6]:
# Display the additive bias: 0 where attention is allowed, -inf where it is blocked.
attn_bias

tensor([[[[0., -inf, -inf, -inf],
          [0., 0., -inf, -inf],
          [0., 0., 0., -inf],
          [0., 0., 0., 0.]]]])

In [7]:
# Raw attention scores: Q times K transposed over the last two axes, then scaled.
key_t = K_reshaped.transpose(-2, -1)
attn_weight = torch.matmul(Q_reshaped, key_t) * scale_factor

attn_weight

tensor([[[[-0.9642, -0.2954, -0.4285,  0.1722],
          [-1.3896, -0.4724, -0.2599,  0.3112],
          [-3.7753, -0.8509, -1.5718,  1.0857],
          [-4.7673, -1.0236, -1.9470,  1.2991]],

         [[ 0.3718,  1.2146, -0.2810,  1.3091],
          [ 0.2635,  0.3532,  0.6093,  1.2817],
          [ 0.1330,  0.8939,  1.1148,  2.1187],
          [ 0.1623,  0.8337,  2.3999,  3.6650]],

         [[ 1.1455,  0.1479,  0.0111, -0.6471],
          [ 0.4431, -0.0767,  0.0291, -0.2469],
          [-1.1781, -0.7020, -0.1667,  0.1379],
          [-1.3156, -1.0344, -0.2800,  0.3814]],

         [[ 1.3486,  0.9894,  0.7259,  1.6732],
          [ 1.4106,  1.2545,  0.6902,  1.7540],
          [-0.3775, -0.1246,  1.7738,  2.8116],
          [-0.2113,  0.2160,  2.1442,  3.6235]]],


        [[[-1.3746, -1.0568, -0.7742, -0.5550],
          [-0.3296, -1.0248, -0.8298,  0.2693],
          [-0.3108, -0.4454, -0.7559,  0.1240],
          [-1.0882, -0.6811, -0.3441, -0.5686]],

         [[ 1.4763,  0.9152,

In [8]:
# Add the bias in place (it broadcasts over the two leading axes): blocked
# positions become -inf. add_() returns the same tensor, which is displayed.
attn_weight.add_(attn_bias)

tensor([[[[-0.9642,    -inf,    -inf,    -inf],
          [-1.3896, -0.4724,    -inf,    -inf],
          [-3.7753, -0.8509, -1.5718,    -inf],
          [-4.7673, -1.0236, -1.9470,  1.2991]],

         [[ 0.3718,    -inf,    -inf,    -inf],
          [ 0.2635,  0.3532,    -inf,    -inf],
          [ 0.1330,  0.8939,  1.1148,    -inf],
          [ 0.1623,  0.8337,  2.3999,  3.6650]],

         [[ 1.1455,    -inf,    -inf,    -inf],
          [ 0.4431, -0.0767,    -inf,    -inf],
          [-1.1781, -0.7020, -0.1667,    -inf],
          [-1.3156, -1.0344, -0.2800,  0.3814]],

         [[ 1.3486,    -inf,    -inf,    -inf],
          [ 1.4106,  1.2545,    -inf,    -inf],
          [-0.3775, -0.1246,  1.7738,    -inf],
          [-0.2113,  0.2160,  2.1442,  3.6235]]],


        [[[-1.3746,    -inf,    -inf,    -inf],
          [-0.3296, -1.0248,    -inf,    -inf],
          [-0.3108, -0.4454, -0.7559,    -inf],
          [-1.0882, -0.6811, -0.3441, -0.5686]],

         [[ 1.4763,    -inf,

In [32]:
# Recompute the biased scores from their inputs instead of reading the current
# value of `attn_weight`. The original `attn_weight = softmax(attn_weight)` was not
# idempotent: re-running the cell applied softmax to already-normalised
# probabilities. The result is unchanged when the cells run in order.
masked_scores = Q_reshaped @ K_reshaped.transpose(-2, -1) * scale_factor + attn_bias
attn_weight = torch.softmax(masked_scores, dim=-1)

attn_weight

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.2855, 0.7145, 0.0000, 0.0000],
          [0.0349, 0.6494, 0.3158, 0.0000],
          [0.0020, 0.0860, 0.0342, 0.8778]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4776, 0.5224, 0.0000, 0.0000],
          [0.1721, 0.3684, 0.4595, 0.0000],
          [0.0220, 0.0430, 0.2058, 0.7292]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.6271, 0.3729, 0.0000, 0.0000],
          [0.1866, 0.3004, 0.5130, 0.0000],
          [0.0944, 0.1250, 0.2658, 0.5149]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5390, 0.4610, 0.0000, 0.0000],
          [0.0919, 0.1183, 0.7898, 0.0000],
          [0.0168, 0.0258, 0.1776, 0.7797]]],


        [[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.6671, 0.3329, 0.0000, 0.0000],
          [0.3976, 0.3476, 0.2548, 0.0000],
          [0.1590, 0.2389, 0.3347, 0.2674]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.3077, 0.6923, 0.0000, 0.0000],
          [0.2333, 0

In [29]:
# Attention output: combine the value rows using the attention weights.
torch.matmul(attn_weight, V_reshaped)

tensor([[[[-1.0260,  0.9662, -2.2913,  1.8057],
          [ 0.1327, -0.0359, -1.0905,  0.6411],
          [ 0.7151, -0.0478, -0.3283,  0.2341],
          [ 1.8837, -0.4538,  1.3778, -0.2403]],

         [[-0.0240, -1.3353, -1.6751, -0.9049],
          [-0.2554, -1.5993, -1.2891, -0.5788],
          [-0.0426, -0.7094, -0.1399, -0.4250],
          [-0.4329, -1.7725,  1.5505,  0.3828]],

         [[ 1.5304, -1.8514, -1.5147,  0.9213],
          [ 1.3223, -1.4896, -1.2584,  0.9776],
          [ 1.2779, -0.8194, -1.0738, -0.5182],
          [ 1.5191, -1.3656, -1.6701, -0.4396]],

         [[ 0.5280,  0.7117, -0.2574,  1.5133],
          [ 0.5139,  0.8174, -0.3106,  0.4195],
          [ 0.1404, -0.1188,  0.0727,  0.2776],
          [-0.3848, -0.5523, -0.3987, -0.2975]]],


        [[[-0.2025, -0.9508, -1.4185, -0.6180],
          [-0.3135, -0.5281, -0.9126, -0.2718],
          [-0.0121, -0.4732, -0.7824, -0.0660],
          [ 0.1751, -0.2853, -0.8234, -0.2669]],

         [[-0.2939, -0.2418,