In [1]:
import torch
from jax.experimental.pallas.ops.gpu.attention_mgpu import attention

In [8]:
# Toy 3-dimensional embeddings, one row per token of the example sentence.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

In [9]:
# Use the second token, "journey" (x^2), as the query for the walkthrough below.
input_query = inputs[1]
input_query

tensor([0.5500, 0.8700, 0.6600])

In [10]:
# First token, "Your" (x^1): the key that the query is compared against first.
input_1 = inputs[0]
input_1

tensor([0.4300, 0.1500, 0.8900])

In [11]:
# Unnormalised attention score between the query ("journey") and the first token ("Your").
torch.dot(input_query,input_1)

tensor(0.9544)

In [12]:
# Attention scores of the query against every token: one dot product per row of `inputs`.
atn_score = torch.empty(inputs.shape[0])
# The index is named `i`, not `id`: a loop variable called `id` would stay in the
# notebook's global namespace and shadow the builtin id() for every later cell.
for i, value in enumerate(inputs):
    atn_score[i] = torch.dot(value, input_query)

atn_score

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

q@k^T computes the dot product between query and keys, giving you attention scores
softmax(q@k^T) converts these scores into a probability distribution (attention weights)
softmax(q@k^T)V multiplies these attention weights with the value vectors and sums them up

What is happening: for every token, we are working out which of the other tokens' values
matter most to it. If this stops making sense again, draw some pictures and
work out by hand what the dot products give.

In [13]:
# Pairwise attention scores: entry [i, j] is the dot product of token i with token j.
# The matrix is sized from `inputs` rather than hard-coded, so the cell keeps working
# if the example sentence changes length.
num_tokens = inputs.shape[0]
attn_scores = torch.empty(num_tokens, num_tokens)

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)

print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [14]:
# Same score matrix as the nested loops, computed with a single matrix product.
attn_scores = torch.matmul(inputs, inputs.T)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [15]:
# Normalise along the last axis so that each row of weights sums to 1.
attn_weights = attn_scores.softmax(dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


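For a tensor of shape [batch, seq, feature]: dim=0 is batch, dim=1 is seq, dim=2 is features. dim=-1 always means the last axis; for the 2-D score matrix above, that normalises each row so it sums to 1.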

In [16]:
# Each context vector is the attention-weighted sum of all token embeddings.
all_context = torch.matmul(attn_weights, inputs)
print(all_context)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


In [2]:
import torch.nn as nn

In [18]:
class Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are separate linear projections of the same
    input, each mapping ``embed_dimension`` to ``embed_dimension``.
    """

    def __init__(self, embed_dimension):
        super().__init__()
        self.embed = embed_dimension

        # Layers are created in the order Q, K, V.
        self.Q = nn.Linear(embed_dimension, embed_dimension)
        self.K = nn.Linear(embed_dimension, embed_dimension)
        self.V = nn.Linear(embed_dimension, embed_dimension)

    def forward(self, inputs):
        """Return one context vector per token of a 3-D ``inputs`` tensor."""
        # Unpacking three sizes rejects anything that is not 3-D.
        _batch, _seq_len, _embed = inputs.shape

        q = self.Q(inputs)
        k = self.K(inputs)
        v = self.V(inputs)

        # Scores are divided by sqrt(embed_dimension) before the softmax.
        scale = torch.tensor(self.embed, dtype=torch.float32).sqrt()
        scores = torch.bmm(q, k.transpose(1, 2)) / scale

        # Functional softmax over the key axis; no module object is needed.
        weights = torch.softmax(scores, dim=-1)

        return torch.bmm(weights, v)


In [19]:
# Smoke test on random data shaped (8, 4, 256).
# The input is drawn before the module is built, so the RNG is consumed in the same order.
inp = torch.randn(8, 4, 256)
atn = Attention(embed_dimension=256)
score = atn(inp)

In [20]:
# The output should have the same shape as `inp`.
score.shape

torch.Size([8, 4, 256])

In [21]:
# Output vector at index [0][0]: first batch entry, first position.
score[0][0]

tensor([-0.4144,  0.3002, -0.0115, -0.1468,  0.2303, -0.1717, -0.0537,  0.1635,
        -0.3045, -0.1517, -0.1498,  0.0202, -0.3873,  0.1422, -0.2053, -0.3109,
         0.0830,  0.2286, -0.1070,  0.3913, -0.0706,  0.5189, -0.1951,  0.3627,
        -0.2387,  0.1900, -0.2856,  0.0265, -0.0230,  0.4463,  0.1618, -0.2574,
        -0.2262, -0.4509, -0.0500,  0.2093, -0.1159,  0.3719, -0.1064, -0.1726,
        -0.0714, -0.2970,  0.0770,  0.2337,  0.4028, -0.3590,  0.3574,  0.6720,
         0.5702, -0.0287,  0.2062,  0.4490,  0.1704,  0.3056,  0.0718, -0.1337,
         0.1799, -0.2567,  0.6767, -0.4014,  0.2725, -0.1527,  0.3658, -0.1617,
         0.3129,  0.3069,  0.3944,  0.2447, -0.0680,  0.4160, -0.6410, -0.1520,
         0.0715, -0.1772,  0.4450,  0.0840, -0.2896, -0.3560,  0.4340,  0.0432,
         0.0194, -0.3525,  0.1243,  0.2352,  0.3248, -0.1320,  0.2128,  0.2866,
        -0.1223,  0.2483, -0.7347,  0.2327,  0.0905,  0.6703, -0.1278, -0.0797,
         0.3985, -0.5234, -0.0574, -0.57

How Language Models Are Actually Trained
Language models like GPT are typically trained using a technique called "autoregressive language modeling" or "next-token prediction." Here's how it works:

We take a sequence like "I'm going to school tomorrow"
We don't split it into separate "input" and "output" parts
Instead, we train the model to predict each token based on all previous tokens

So the training pairs look like:

Given "I'm", predict "going"
Given "I'm going", predict "to"
Given "I'm going to", predict "school"
Given "I'm going to school", predict "tomorrow"

Each token in the sequence serves both as input (context for every later prediction) and as target (the correct answer for the prediction made at the position just before it).
Why Causal Masking Is Essential
Now you can see why causal masking is crucial:
Without masking, when trying to predict "school", the model would have access to "tomorrow" in its attention mechanism. This defeats the purpose of predicting the next word, since the model already sees it!
The causal mask ensures that prediction at each position can only use information from previous positions, matching how the model will be used during generation.
Contrast with Traditional Supervised Learning
The approach you initially described:

"Data is 'I'm going to school', label is 'going to school tomorrow'"

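Read as two unrelated sequences, this looks like a traditional encoder-decoder setup (e.g., for translation or summarization), where you have distinct input and output sequences, and some language models are indeed trained that way for specific tasks. Read position by position, however, it is exactly the autoregressive approach described above: the label is the input shifted left by one token, so the target at each position is simply the next token. With a causal mask, all of these next-token predictions are trained in parallel in a single forward pass, and this is how the fundamental pretraining of models like GPT works.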

In [22]:
class AttentionMask(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Position ``i`` may only attend to positions ``0..i``. The mask is built
    inside ``forward`` from the current sequence length.
    """

    def __init__(self, embed_dimension):
        super().__init__()
        self.embed = embed_dimension

        # Layers are created in the order Q, K, V.
        self.Q = nn.Linear(embed_dimension, embed_dimension)
        self.K = nn.Linear(embed_dimension, embed_dimension)
        self.V = nn.Linear(embed_dimension, embed_dimension)

    def forward(self, inputs):
        """Return one causally-masked context vector per token of a 3-D ``inputs``."""
        _batch, seq_len, _embed = inputs.shape

        q = self.Q(inputs)
        k = self.K(inputs)
        v = self.V(inputs)

        scale = torch.tensor(self.embed, dtype=torch.float32).sqrt()
        scores = torch.bmm(q, k.transpose(1, 2)) / scale

        # True wherever the key position lies strictly after the query position.
        hidden_future = torch.ones(seq_len, seq_len, device=inputs.device).tril() == 0
        # -inf scores become exactly zero weight after the softmax.
        scores = scores.masked_fill(hidden_future, float('-inf'))
        weights = torch.softmax(scores, dim=-1)

        return torch.bmm(weights, v)


In [23]:
# Same smoke test as for the unmasked module, now with the causal variant.
# The input is drawn before the module is built, so the RNG is consumed in the same order.
inp = torch.randn(8, 4, 256)
atn = AttentionMask(embed_dimension=256)
score = atn(inp)

In [24]:
# The output should have the same shape as `inp`.
score.shape

torch.Size([8, 4, 256])

In [25]:
# Output vector at index [0][0]: first batch entry, first position.
score[0][0]

tensor([ 2.8060e-01,  3.0595e-01, -2.3299e-01,  8.4332e-01, -1.5796e+00,
         4.7608e-01,  9.8869e-01,  9.1234e-01,  3.2835e-01, -1.0854e-01,
        -2.5397e-01, -9.7945e-01, -4.6993e-01,  5.0027e-01, -1.4757e-01,
         5.7748e-01,  6.7749e-01, -4.5327e-01,  9.3564e-01,  6.2664e-01,
        -1.1262e+00, -4.3796e-02, -4.2575e-01,  1.8760e-01,  7.8315e-01,
         5.7906e-01, -8.3422e-01, -7.9052e-01,  2.2270e-01,  4.6146e-02,
        -1.1209e+00, -3.9621e-01,  3.6719e-02,  7.5170e-01, -7.0961e-01,
        -4.9098e-01,  1.2279e-01, -2.5331e-01,  3.3222e-01,  2.4698e-01,
         1.1430e+00,  7.5232e-03, -2.7155e-01, -9.2286e-01, -2.1283e-01,
        -1.3296e-01, -1.3904e-01, -2.1041e-01, -1.5030e-01,  3.0007e-01,
        -2.8098e-01, -4.1020e-01, -6.5312e-02, -5.9277e-01, -8.7140e-01,
        -1.4851e+00,  6.8901e-01,  6.2157e-02, -2.5424e-01,  9.6720e-01,
         1.5572e-01, -7.8172e-01, -1.0494e-01, -4.3082e-01,  7.9993e-01,
        -8.9199e-02,  2.2814e-01,  7.5920e-02, -6.6

In [11]:
# The earlier classes did not take a context_length argument because they built the
# causal mask inside forward(), sized to the current sequence rather than to the longest
# sequence the model accepts. The drawback is that the mask has to be rebuilt on every
# forward pass, which is slower during training.

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a precomputed mask.

    Args:
        embed_in: size of the last dimension of the input.
        embed_out: total width of the Q/K/V projections and of the output;
            it is split evenly across the heads.
        context_length: side length of the precomputed causal mask, i.e. the
            longest sequence ``forward`` accepts.
        heads: number of attention heads; must divide ``embed_out``.
        dropout: dropout probability applied to the attention weights.
        bias: whether the Q/K/V projections have a bias term.
    """

    def __init__(self,embed_in,embed_out,context_length ,heads,dropout=0,bias = False):
        super().__init__()

        assert embed_out % heads == 0, "embed_out must be divisible by heads"

        self.heads = heads
        self.d_size = embed_out // heads

        self.Q = nn.Linear(embed_in,embed_out,bias=bias)
        self.K = nn.Linear(embed_in,embed_out,bias=bias)
        self.V = nn.Linear(embed_in,embed_out,bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Linear(embed_out,embed_out)
        # Registering the mask as a buffer puts it in state_dict() and moves it
        # together with the module on model.to(device).
        # Entries equal to 1 (strictly above the diagonal) mark future positions.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length),
                                diagonal=1))

    def _split_heads(self, projected, batch, seq_len):
        """Reshape (batch, seq, embed_out) into (batch, heads, seq, d_size)."""
        return projected.view(batch, seq_len, self.heads, self.d_size).transpose(1, 2)

    def forward(self,text):
        """Return causally-masked attention output of shape (batch, seq, embed_out).

        Raises:
            ValueError: if the sequence is longer than the precomputed mask.
        """
        batch,seq_len,_ = text.shape
        context_length = self.mask.shape[0]
        if seq_len > context_length:
            raise ValueError(
                f"sequence length {seq_len} exceeds context_length {context_length}"
            )

        Q_text = self._split_heads(self.Q(text), batch, seq_len)
        K_text = self._split_heads(self.K(text), batch, seq_len)
        V_text = self._split_heads(self.V(text), batch, seq_len)

        # torch.bmm only handles 3-D tensors, so the 4-D product uses `@`.
        attention_score = Q_text @ K_text.transpose(-2, -1) / self.d_size ** 0.5

        # Out-of-place fill: the score tensor is not modified behind autograd's back.
        mask = self.mask[:seq_len, :seq_len].bool()
        attention_score = attention_score.masked_fill(mask, float("-inf"))
        attention_weights = torch.softmax(attention_score, dim=-1)
        attention_weights = self.dropout(attention_weights)

        context_vector = attention_weights @ V_text  # (batch, heads, seq, d_size)
        # Move heads back next to d_size BEFORE flattening. Reshaping the
        # (batch, heads, seq, d_size) tensor directly would interleave different
        # tokens and heads, and would leak future positions into earlier outputs.
        context_vector = context_vector.transpose(1, 2).reshape(
            batch, seq_len, self.heads * self.d_size
        )
        return self.projection(context_vector)




In [12]:
# Smoke test: a 4-token batch fits inside the 5-token context window.
atn = MultiHeadAttention(embed_in=256, embed_out=256, context_length=5, heads=8)
print(atn(torch.randn(8, 4, 256)))

tensor([[[ 6.2263e-02,  1.7006e-02, -6.6266e-01,  ..., -2.0844e-01,
           2.0626e-01, -2.0393e-01],
         [ 6.9717e-02, -5.2152e-01,  6.1783e-02,  ...,  1.1413e-01,
          -1.4347e-01, -3.0959e-01],
         [ 1.0778e-01,  2.0654e-01, -3.2684e-01,  ...,  1.2353e-01,
          -8.2114e-02,  3.0593e-01],
         [-1.9528e-01,  3.4111e-01, -3.2996e-01,  ..., -1.7057e-01,
          -2.4177e-01,  9.9616e-02]],

        [[-5.8231e-01,  3.3775e-02,  2.7034e-02,  ..., -2.7374e-01,
           4.8551e-02,  2.4596e-01],
         [-3.2613e-01, -2.0954e-01, -3.9494e-02,  ..., -1.4908e-01,
          -2.7056e-01, -1.9840e-02],
         [ 4.8427e-02, -1.8274e-01,  4.8473e-01,  ...,  2.3740e-01,
           5.4003e-02, -1.1858e-01],
         [ 1.1849e-01, -3.5461e-01, -2.8840e-01,  ...,  6.3750e-01,
           2.5891e-01,  2.1848e-01]],

        [[-2.9731e-01, -1.4747e-01, -2.1943e-01,  ...,  7.4954e-02,
          -3.1236e-02,  5.8017e-02],
         [-2.2283e-02, -9.6130e-03,  3.8290e-01,  .