In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention for sequence-first tensors.

    All inputs and the output use the ``(seq_len, batch_size, embed_dim)``
    layout. The embedding is split into ``num_heads`` equal slices of size
    ``head_dim = embed_dim // num_heads``; attention is computed independently
    per slice and the results are concatenated and passed through a final
    linear projection.

    Args:
        embed_dim: Size of the embedding dimension of the inputs and output.
        num_heads: Number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Separate input projections for queries, keys and values.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Mixes the concatenated per-head outputs back into embed_dim.
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, x, length, batch_size):
        """Reshape ``(length, batch, embed_dim)`` to ``(batch, heads, length, head_dim)``."""
        # The explicit ``length`` (rather than -1) makes a batch-size mismatch
        # between query and key/value raise instead of silently re-partitioning.
        x = x.reshape(length, batch_size, self.num_heads, self.head_dim)
        return x.permute(1, 2, 0, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: Tensor of shape ``(tgt_len, batch_size, embed_dim)``.
            key: Tensor of shape ``(src_len, batch_size, embed_dim)``.
            value: Tensor of shape ``(src_len, batch_size, embed_dim)``.
            mask: Optional tensor in which entries equal to 0 (or ``False``)
                mark positions that must not be attended to. Accepted shapes:
                ``(batch_size, tgt_len, src_len)`` for a per-sample mask, or
                anything that broadcasts against
                ``(batch_size, num_heads, tgt_len, src_len)``, e.g. a shared
                ``(tgt_len, src_len)`` causal mask.

        Returns:
            Tensor of shape ``(tgt_len, batch_size, embed_dim)``.

        Note:
            A query row whose keys are all masked receives ``-inf`` for every
            score, so its softmax (and therefore its output) is NaN.
        """
        tgt_len, batch_size, embed_dim = query.size()

        # Project, then move to (batch, heads, length, head_dim).
        q = self._split_heads(self.q_proj(query), tgt_len, batch_size)
        k = self._split_heads(self.k_proj(key), key.size(0), batch_size)
        v = self._split_heads(self.v_proj(value), value.size(0), batch_size)

        # Scaled dot-product scores: (batch, heads, tgt_len, src_len).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            if mask.dim() == 3:
                # Per-sample mask: add a singleton head axis so the same mask
                # is broadcast to every head without materialising copies.
                mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)

        # Weighted sum of values: (batch, heads, tgt_len, head_dim).
        context = torch.matmul(attn, v)

        # Back to sequence-first layout with the heads concatenated.
        context = context.permute(2, 0, 1, 3).reshape(tgt_len, batch_size, embed_dim)

        return self.out_proj(context)


In [29]:


# Demonstration: Tensor.view() reinterprets the shape of existing data
# instead of allocating a copy.
x = torch.arange(6)       # tensor([0, 1, 2, 3, 4, 5])
y = x.view(2, 3)         # reshape without copying data

# Tensor.data_ptr() returns the address of the tensor's first element. It is
# used here instead of Tensor.storage(), which is deprecated in PyTorch 2.x.
print(x.data_ptr())  # Memory address of x's data
print(y.data_ptr())  # Same address — they share data!

# Make the claim above a checked invariant rather than an eyeballed output.
assert x.data_ptr() == y.data_ptr()


3486104227904
3486104227904


In [30]:
# Smoke test: self-attention over a random batch must preserve the input shape.
# Seed first so both the layer initialisation and the random input are
# reproducible after a kernel restart.
torch.manual_seed(0)

embed_dim = 64
num_heads = 8
seq_len = 10
batch_size = 2

mha = MultiHeadAttention(embed_dim, num_heads)
x = torch.rand(seq_len, batch_size, embed_dim)

out = mha(x, x, x)
assert out.shape == (seq_len, batch_size, embed_dim)
print(out.shape)

torch.Size([10, 2, 64])
