In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


def create_token_embeddings(sentence, embedding_dim=16, seed=123123):
    """Embed each whitespace-separated token of ``sentence`` with a freshly initialised ``nn.Embedding``.

    The vocabulary is the sorted set of distinct tokens, numbered from 1. Id 0 is never
    assigned; the table has one spare row for it, presumably reserved for padding.
    Repeated tokens map to the same id and therefore to identical embedding rows.

    Args:
        sentence: Input text. Tokenisation is a plain ``str.split()``, so punctuation
            stays attached to words (e.g. ``"now."`` is one token).
        embedding_dim: Length of each embedding vector (default 16).
        seed: Value passed to ``torch.manual_seed`` so the random embedding weights
            are reproducible (default 123123).

    Returns:
        Tensor of shape ``(num_tokens, embedding_dim)`` produced by the embedding
        lookup. An empty sentence yields shape ``(0, embedding_dim)``.

    Side effects:
        Reseeds torch's global RNG via ``torch.manual_seed(seed)``.
    """
    torch.manual_seed(seed)
    tokens = sentence.split()

    # Ids start at 1, leaving 0 unused.
    word_to_idx = {word: i for i, word in enumerate(sorted(set(tokens)), 1)}
    # Explicit dtype: torch.tensor([]) would infer float32, which nn.Embedding rejects
    # as an index tensor, so an empty sentence used to crash here.
    token_tensor = torch.tensor([word_to_idx[word] for word in tokens], dtype=torch.long)
    # +1 row because ids run from 1..len(vocab) inclusive.
    embedding = nn.Embedding(len(word_to_idx) + 1, embedding_dim)
    return embedding(token_tensor)


sentence = "Attention is all you need for now."
embedding = create_token_embeddings(sentence)
# 7 whitespace-separated tokens, so the embeddings are 7 rows of 16 values.
print(embedding, embedding.shape, sep="\n")

In [None]:
def compute_context_vector(X, dk=24, dv=28, seed=123123):
    """Single-head scaled dot-product self-attention with random (untrained) projections.

    Seeds the global RNG, draws query/key/value projection matrices with ``torch.randn``,
    projects ``X`` with them and returns ``softmax(Q K^T / sqrt(dk)) V``.

    Args:
        X: Input embeddings with the embedding size on the last axis, e.g.
            ``(seq_len, d_in)``. Leading batch dimensions such as
            ``(batch, seq_len, d_in)`` are also supported.
        dk: Query/key projection size (default 24). Queries and keys share it so
            their dot product is defined.
        dv: Value projection size (default 28), i.e. the length of each context vector.
        seed: Value passed to ``torch.manual_seed`` before the projections are drawn.

    Returns:
        Context vectors of shape ``X.shape[:-1] + (dv,)``.

    Side effects:
        Reseeds torch's global RNG via ``torch.manual_seed(seed)``.
    """
    torch.manual_seed(seed)
    d_in = X.size(-1)

    # Weights are laid out as (out_features, in_features), hence the .t() below.
    W_q = torch.randn(dk, d_in)
    W_k = torch.randn(dk, d_in)
    W_v = torch.randn(dv, d_in)

    queries = torch.matmul(X, W_q.t())
    keys = torch.matmul(X, W_k.t())
    values = torch.matmul(X, W_v.t())

    # Swap only the last two axes. Tensor.T reverses *all* axes, which breaks batched
    # input and is deprecated for tensors that are not 2-D.
    scores = torch.matmul(queries, keys.transpose(-2, -1))
    # Divide by sqrt(dk) so large dot products do not push the softmax into saturation.
    weights = F.softmax(scores / (dk**0.5), dim=-1)

    return torch.matmul(weights, values)


sentence = "Attention is all you need for now."
embedding = create_token_embeddings(sentence)
context_vector = compute_context_vector(embedding)
# One context vector per token: shape (7, 28).
print(context_vector, context_vector.shape, sep="\n")

In [None]:
def multi_head_attention(X, h, dk=24, dv=28, seed=123123):
    """Run ``h`` independent scaled dot-product self-attention heads with random projections.

    The global RNG is seeded once. Each head then draws its own query/key/value
    projection matrices in turn, so heads differ from one another while the whole
    result stays reproducible.

    Args:
        X: Input embeddings with the embedding size on the last axis, e.g.
            ``(seq_len, d_in)``. Leading batch dimensions are also supported.
        h: Number of heads; must be at least 1.
        dk: Query/key projection size per head (default 24).
        dv: Value projection size per head (default 28).
        seed: Value passed to ``torch.manual_seed`` before any projection is drawn.

    Returns:
        Tensor of shape ``X.shape[:-1] + (dv, h)``. Head outputs are *stacked* on a new
        last axis (``out[..., i]`` is head ``i``), not concatenated into
        ``(..., h * dv)`` as in the Transformer's multi-head attention.

    Raises:
        ValueError: If ``h < 1``. Otherwise ``torch.stack`` would fail on an empty list.

    Side effects:
        Reseeds torch's global RNG via ``torch.manual_seed(seed)``.
    """
    if h < 1:
        raise ValueError(f"h must be a positive number of heads, got {h}")

    torch.manual_seed(seed)
    d_in = X.size(-1)

    head_outputs = []
    for _ in range(h):
        # Drawn in q, k, v order for every head, so each head's weights are fixed by
        # the seed and the head's position.
        W_q = torch.randn(dk, d_in)
        W_k = torch.randn(dk, d_in)
        W_v = torch.randn(dv, d_in)

        queries = torch.matmul(X, W_q.t())
        keys = torch.matmul(X, W_k.t())
        values = torch.matmul(X, W_v.t())

        # transpose(-2, -1) instead of .T so inputs with batch dimensions work.
        scores = torch.matmul(queries, keys.transpose(-2, -1))
        weights = F.softmax(scores / (dk**0.5), dim=-1)
        head_outputs.append(torch.matmul(weights, values))

    return torch.stack(head_outputs, dim=-1)


sentence = "Attention is all you need for now."
embedding = create_token_embeddings(sentence)
context_matrix = multi_head_attention(embedding, h=5)
# Five heads stacked on the last axis: shape (7, 28, 5).
print(context_matrix, context_matrix.shape, sep="\n")

- Benefits
  - Intuitively, I think that cross-attention would allow the model to
    selectively attend to different parts of one input sequence based on the
    interactions between queries derived from the other sequence and keys derived
    from this one (queries come from one sequence, keys and values from the other),
    thereby enabling the model to capture more complex relationships between the
    elements of different input sequences.
  - I also think that cross-attention would enable the model to effectively fuse
    information from different input sequences. By considering the interactions
    between queries and keys that come from different input sequences, the model
    may be able to generate context vectors that incorporate relevant information
    from both input sequences, leading to richer representations.

- Use cases
  - In machine translation, to align words or tokens in the source language with
    those in the target language (the decoder's queries attend over the keys and
    values computed from the encoded source sentence).
  - In document summarization, to generate summaries by selectively attending to
    the most important information in the document, based on its relevance to the
    part of the summary generated so far.
  - In question answering, to find the passages or sentences in a document that
    contain the answer to a given question.