In [1]:
import torch as th
import torch.nn.functional as F
import math
import time
import torch.nn as nn

def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    attn_logits = th.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = th.matmul(attention, v)
    return values

def batched_scaled_dot_product(q,k,v,mask,num_batches=2):
    chunked_queries = q.chunk(num_batches, dim=-2)
    result = None
    for cq in chunked_queries:
        current_result = scaled_dot_product(cq, k, v)
        if result is None:
            result = current_result
        else:
            result = th.cat((result, current_result), dim=-2)
    return result

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
''''batch_size = 10
lenght = 10000
dim=2048

q = th.zeros([batch_size,lenght,dim], device='cuda')
k = th.zeros([batch_size,lenght,dim], device='cuda')
v = th.zeros([batch_size,lenght,dim], device='cuda')'''

#values2, attention = scaled_dot_product(q,k,v)

#th.equal(values1, values2)

"'batch_size = 10\nlenght = 10000\ndim=2048\n\nq = th.zeros([batch_size,lenght,dim], device='cuda')\nk = th.zeros([batch_size,lenght,dim], device='cuda')\nv = th.zeros([batch_size,lenght,dim], device='cuda')"

In [3]:
class MultiheadAttention(nn.Module):

    def __init__(self, input_dim, embed_dim, num_heads, num_chunks):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Stack all weight matrices 1...h together for efficiency
        # Note that in many implementations you see "bias=False" which is optional
        self.qkv_proj = nn.Linear(input_dim, 3*embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self.embed_dim = embed_dim
        self.num_chunks = num_chunks

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, mask=None, batched = True):
        qkv = self.qkv_proj(x)
        batch_size, seq_length, _ = x.size()


        # Separate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        # Determine value outputs
        if batched:
            values = batched_scaled_dot_product(q, k, v, mask=mask, num_batches=self.num_chunks)
        else:
            values = scaled_dot_product(q, k, v)
        values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(values)
        return o

In [4]:
mha = MultiheadAttention(1, embed_dim=10, num_heads=1, num_chunks = 10).to('cuda')


In [18]:
seq_len = 10000
inpt = th.zeros([10,seq_len,1], device='cuda')
inpt[:5,0,0] = 1

result = th.zeros([10,seq_len,10], device='cuda')
result[:5,-1,0] = 1

In [6]:
inpt[0,0,0]

tensor(1., device='cuda:0')

In [19]:
optimiter = th.optim.Adam(mha.parameters(), lr=1e-3)


In [20]:
num_epochs = 10
start_event = th.cuda.Event(enable_timing=True)
end_event = th.cuda.Event(enable_timing=True)
start_event.record()

# Run some things here


for i in range(num_epochs):

    erg = mha.forward(inpt, batched=True)

    loss = ((erg[:,-1,0].reshape(-1) - result[:,-1,0].reshape(-1))**2).mean()
    optimiter.zero_grad()
    loss.backward()
    optimiter.step()

    if i%10==0:
        print(loss)
end_event.record()
th.cuda.synchronize()  # Wait for the events to be recorded!
elapsed_time_ms = start_event.elapsed_time(end_event)
print(elapsed_time_ms)

tensor(0.3874, device='cuda:0', grad_fn=<MeanBackward0>)
5851.45751953125


In [16]:
num_epochs = 10
start_event = th.cuda.Event(enable_timing=True)
end_event = th.cuda.Event(enable_timing=True)
start_event.record()

# Run some things here


for i in range(num_epochs):

    erg = mha.forward(inpt, batched=True)



    loss = ((erg[:,-1,0].reshape(-1) - result[:,-1,0].reshape(-1))**2).mean()
    optimiter.zero_grad()
    loss.backward()
    optimiter.step()

    if i%10==0:
        print(loss)
end_event.record()
th.cuda.synchronize()  # Wait for the events to be recorded!
elapsed_time_ms = start_event.elapsed_time(end_event)
print(elapsed_time_ms)

tensor(5.6401e-14, device='cuda:0', grad_fn=<MeanBackward0>)
77.80659484863281


In [9]:
num_epochs = 1000
optimiter = th.optim.Adam(mha.parameters(), lr=1e-3)

for i in range(num_epochs):
    h = time.perf_counter()
    erg = mha.forward(inpt, batched=False)
    print(time.perf_counter()-h)
    print('_______________________')
    loss = ((erg[:,-1,0].reshape(-1) - result[:,-1,0].reshape(-1))**2).mean()
    optimiter.zero_grad()
    loss.backward()
    optimiter.step()
    if i%10==0:
        print(loss)

0.0011303320000024542
_______________________
tensor(0.4541, device='cuda:0', grad_fn=<MeanBackward0>)
0.00037948600015624834
_______________________
0.00021827199998369906
_______________________
0.0001945070000601845
_______________________
0.00019128100007037574
_______________________
0.0001674659999935102
_______________________
0.00021006600013606658
_______________________
0.0002517249999982596
_______________________
0.00019421700017119292
_______________________
0.00017040200009432738
_______________________
0.00016793700001471734
_______________________
tensor(0.4115, device='cuda:0', grad_fn=<MeanBackward0>)
0.00039152900012595637
_______________________
0.0001655220000884583
_______________________
0.0001600420000613667
_______________________
0.00016763700000410608
_______________________
0.00016102399990813865
_______________________
0.0001608840000244527
_______________________
0.000160221999976784
_______________________
0.00016156500009856245
_______________________
0.

KeyboardInterrupt: 

In [None]:
erg[0,-1,0]

tensor(1.0000, device='cuda:0', grad_fn=<SelectBackward0>)

In [None]:
result[0,-1,0]

tensor(1., device='cuda:0')