# Self Attention

Self attention is a mechanism that enhances the information content of an input embedding by including information about the input's context.


In [1]:
import torch 
import torch.nn as nn 

In [2]:
sentence = "I love cakes, specially chocolate cakes. Lately, I want to try vanilla cake."

dict_words = {s:i for i, s in enumerate(sorted(sentence.replace(',', '').split()))}
dict_words 

{'I': 1,
 'Lately': 2,
 'cake.': 3,
 'cakes': 4,
 'cakes.': 5,
 'chocolate': 6,
 'love': 7,
 'specially': 8,
 'to': 9,
 'try': 10,
 'vanilla': 11,
 'want': 12}

In [3]:
# assign integer index to each word in the sentence
sentence_indices = torch.tensor([dict_words[word] for word in sentence.replace(',', '').split()])
sentence_indices

tensor([ 1,  7,  4,  8,  6,  5,  2,  1, 12,  9, 10, 11,  3])

In [4]:
embed = torch.nn.Embedding(13,16) 
embedded_sentence = embed(sentence_indices)
embedded_sentence.shape 

torch.Size([13, 16])

### Self attention (Scaled Dot Product Attention)

$W_q, W_k, W_v$ are weight matrices that are adjusted during model training.
Key, Query, Value sequences are obtained by matrix multiplication between the weight matrices **W** and input embeddings **x**.

In [5]:
embedded_sentence.shape 

torch.Size([13, 16])

In [6]:
# dimension of word, and the vectors
d_word = embedded_sentence.shape[1]
d_q = 16
d_k = 16
d_v = 20 

# initialize the weight matrices
W_q = torch.nn.Parameter(torch.randn(d_q, d_word))
W_k = torch.nn.Parameter(torch.randn(d_k, d_word))
W_v = torch.nn.Parameter(torch.randn(d_v, d_word))

W_q.shape, W_k.shape, W_v.shape

(torch.Size([16, 16]), torch.Size([16, 16]), torch.Size([20, 16]))

## Getting attention scores and final weighted values for second element of the sequence

In [7]:
# getting attention-vector for second input element
x_2 = embedded_sentence[1]
query_2 = torch.matmul(W_q, x_2)
key_2 = torch.matmul(W_k, x_2)
value_2 = torch.matmul(W_v, x_2)

query_2.shape, key_2.shape, value_2.shape

(torch.Size([16]), torch.Size([16]), torch.Size([20]))

In [8]:
W_k.shape, embedded_sentence.T.shape

(torch.Size([16, 16]), torch.Size([16, 13]))

In [9]:
# getting keys and values for the remaining sequence too
keys = W_k.matmul(embedded_sentence.T).T 
values = W_v.matmul(embedded_sentence.T).T
keys.shape, values.shape

(torch.Size([13, 16]), torch.Size([13, 20]))

In [10]:
# getting attention scores for the second element
attention_score_2 = query_2.matmul(key_2.T)

# the second element in keys refer to the keys corresponding to the second element 
attention_score_2, query_2.matmul(keys[1].T)

  attention_score_2 = query_2.matmul(key_2.T)


(tensor(-35.4803, grad_fn=<DotBackward0>),
 tensor(-35.4803, grad_fn=<DotBackward0>))

In [11]:
# attention score for second element with respect to all words 
attention_scores_2 = query_2.matmul(keys.T) 
attention_scores_2

tensor([ 29.3837, -35.4803,  10.7918,  78.3755,  47.9072,  49.6510, -30.6851,
         29.3837,  21.4900,  -0.9740, -71.2418,  41.8499, -28.4696],
       grad_fn=<SqueezeBackward4>)

In [12]:
# normalize with softmax, and scale with sqrt(d_k)
# change attention scores to probabilities
attention_weights_2 = torch.nn.functional.softmax(attention_scores_2/torch.sqrt(torch.tensor(d_k))) 
attention_weights_2

  attention_weights_2 = torch.nn.functional.softmax(attention_scores_2/torch.sqrt(torch.tensor(d_k)))


tensor([4.7884e-06, 4.3418e-13, 4.5877e-08, 9.9863e-01, 4.9131e-04, 7.5977e-04,
        1.4398e-12, 4.7884e-06, 6.6548e-07, 2.4218e-09, 5.6874e-17, 1.0807e-04,
        2.5052e-12], grad_fn=<SoftmaxBackward0>)

In [13]:
torch.nn.functional.softmax(attention_scores_2/torch.sqrt(torch.tensor(d_k)), dim=0) 

tensor([4.7884e-06, 4.3418e-13, 4.5877e-08, 9.9863e-01, 4.9131e-04, 7.5977e-04,
        1.4398e-12, 4.7884e-06, 6.6548e-07, 2.4218e-09, 5.6874e-17, 1.0807e-04,
        2.5052e-12], grad_fn=<SoftmaxBackward0>)

In [14]:
attention_weights_2.sum() 

tensor(1.0000, grad_fn=<SumBackward0>)

In [15]:
attention_weights_2.shape

torch.Size([13])

In [16]:
# weighted values
# input sequence with acquired attention
weighted_values_2 = attention_weights_2.matmul(values)
weighted_values_2.shape 

torch.Size([20])

In [43]:
def process_sentences(sentences, embedding_shape, max_sentences):
    dict_words = {s:i for i, s in enumerate(sorted(sentences.replace(',', '').split()))}
    # print(f'Dict words: {dict_words}')
    sentence_indices = [dict_words[word] for word in sentences.replace(',', '').split()]
    sentence_indices = sentence_indices[:max_sentences] + [0] * (max_sentences - len(sentence_indices))  # Pad/truncate
    sentence_indices = torch.tensor(sentence_indices, dtype=torch.long)

    embed = torch.nn.Embedding(len(dict_words)+500,embedding_shape)
    embedded_sentence = embed(sentence_indices)
    return embedded_sentence, dict_words

In [34]:
class SelfAttention(torch.nn.Module):
    def __init__(self, d_word, d_q, d_k, d_v):
        super(SelfAttention, self).__init__()
        self.W_q = torch.nn.Parameter(torch.randn(d_q, d_word))
        self.W_k = torch.nn.Parameter(torch.randn(d_k, d_word))
        self.W_v = torch.nn.Parameter(torch.randn(d_v, d_word))
        
    def forward(self, x):
        query = torch.matmul(self.W_q, x.T)
        key = torch.matmul(self.W_k, x.T)
        value = torch.matmul(self.W_v, x.T)
        
        attention_scores = query.matmul(key.T)
        attention_weights = torch.nn.functional.softmax(attention_scores/torch.sqrt(torch.tensor(d_k)), dim=0)
        # print(f'attention_weights: {attention_weights.shape}')
        # print(f'value shape: {value.shape}')
        weighted_values = attention_weights.matmul(value)
        return weighted_values

In [35]:
class SentimentClassifier(nn.Module):
    def __init__(self, d_word, d_q, d_k, d_v, d_linear):
        super(SentimentClassifier, self).__init__()
        self.attention = SelfAttention(d_word, d_q, d_k, d_v)
        
        # attention values will be taken mean 
        # shape will be (d_word, d_v)
        self.fc = nn.Linear(d_linear, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.attention(x)
        # print(f'After attention: {x.shape}')
        # print(f'After mean, attention: {x.mean(dim=0).shape}')
        x = self.fc(x.mean(dim=0))
        x = self.sigmoid(x)
        return x

In [36]:
train_data = [
    ("I love this product", 1),
    ("This is amazing", 1),
    ("I hate this", 0),
    ("This is terrible", 0),
    ("Absolutely fantastic", 1),
    ("Horrible experience", 0),
    ("Best decision ever", 1),
    ("Worst thing happened", 0),
    ("I really enjoyed using this", 1),
    ("This was a huge disappointment", 0),
    ("Highly recommend this to everyone", 1),
    ("Would not buy again", 0),
    ("Superb quality and performance", 1),
    ("Broke after one use", 0),
    ("Exceeded my expectations", 1),
    ("Not worth the money", 0),
    ("Great value for the price", 1),
    ("Completely useless", 0),
    ("Absolutely love it", 1),
    ("I regret buying this", 0),
    ("Best purchase ever", 1),
    ("The worst experience of my life", 0),
    ("Would purchase again", 1),
    ("This is a waste of time", 0),
    ("Fantastic product, works perfectly", 1),
    ("One of the worst purchases I've made", 0),
    ("So happy with this", 1),
    ("Terrible service and product", 0),
    ("Quality is top-notch", 1),
    ("Completely broke within a week", 0),
    ("Loved the fast delivery and quality", 1),
    ("Would never recommend this to anyone", 0),
    ("A must-have for everyone", 1),
    ("Total waste of money", 0),
    ("Incredible value for the price", 1),
    ("Hated every part of it", 0),
    ("Five stars all the way", 1),
    ("A complete letdown", 0),
    ("The best decision I've made", 1),
    ("Worst experience I've had", 0),
    ("So satisfied with my purchase", 1),
    ("Not worth even a single penny", 0),
    ("Fantastic build and performance", 1),
    ("It failed to work as expected", 0),
    ("Absolutely stunning product", 1),
    ("Terrible quality, do not buy", 0),
    ("Beyond my expectations", 1),
    ("Disappointed with the service", 0),
    ("Great overall experience", 1),
    ("Product broke too soon", 0),
]

train_sentences = [s for s, _ in train_data]

test_sentences = [
    "This is fantastic",
    "I regret buying this",
    "I absolutely love it",
    "Worst experience ever",
    "Best purchase of my life",
]

In [37]:
for sentence, label in train_data:
    print(sentence)
    print(label) 

I love this product
1
This is amazing
1
I hate this
0
This is terrible
0
Absolutely fantastic
1
Horrible experience
0
Best decision ever
1
Worst thing happened
0
I really enjoyed using this
1
This was a huge disappointment
0
Highly recommend this to everyone
1
Would not buy again
0
Superb quality and performance
1
Broke after one use
0
Exceeded my expectations
1
Not worth the money
0
Great value for the price
1
Completely useless
0
Absolutely love it
1
I regret buying this
0
Best purchase ever
1
The worst experience of my life
0
Would purchase again
1
This is a waste of time
0
Fantastic product, works perfectly
1
One of the worst purchases I've made
0
So happy with this
1
Terrible service and product
0
Quality is top-notch
1
Completely broke within a week
0
Loved the fast delivery and quality
1
Would never recommend this to anyone
0
A must-have for everyone
1
Total waste of money
0
Incredible value for the price
1
Hated every part of it
0
Five stars all the way
1
A complete letdown
0
T

In [38]:
train_labels = torch.tensor([label_ for _, label_ in train_data], dtype=torch.float32)
train_labels 

tensor([1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,
        1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,
        1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.])

In [47]:
max_len = len(train_sentences)

In [48]:
embedding_dim = 16
embedded_sentences, word_dict_ = process_sentences(' '.join(train_sentences), embedding_dim, max_len)
embedded_sentences.shape

torch.Size([50, 16])

In [51]:
d_word = embedded_sentences.shape[1]
d_linear = max_len 
d_q = 16        # query dimension
d_k = 16        # key dimension
d_v = 16        # value dimension

# initialize the model
model = SentimentClassifier(d_word, d_q, d_k, d_v, d_linear)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.BCELoss()
model 

SentimentClassifier(
  (attention): SelfAttention()
  (fc): Linear(in_features=50, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [52]:
epochs = 100
i=0
for epoch in range(epochs):
    for sentence, label in train_data:
        label = torch.tensor([label], dtype=torch.float32)
        optimizer.zero_grad()
        embedded_sentences, word_dict_ = process_sentences(sentence, embedding_dim, max_len)
        # print(f'Embedded sentences: {embedded_sentences.shape}')
        output = model(embedded_sentences)
        # print(type(output), output)
        # print(type(label), label)
        
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        i+=1
        if i%10==0:
            print(f'Epoch: {epoch}, Loss: {loss.item()}')

Epoch: 0, Loss: 0.6641170978546143
Epoch: 0, Loss: 0.7257056832313538
Epoch: 0, Loss: 0.6470229625701904
Epoch: 0, Loss: 0.6427892446517944
Epoch: 0, Loss: 0.45382237434387207
Epoch: 1, Loss: 0.6205589175224304
Epoch: 1, Loss: 0.5748776793479919
Epoch: 1, Loss: 0.5451675653457642
Epoch: 1, Loss: 0.7050160765647888
Epoch: 1, Loss: 0.6728764772415161
Epoch: 2, Loss: 0.7525603175163269
Epoch: 2, Loss: 0.7572631239891052
Epoch: 2, Loss: 0.6798258423805237
Epoch: 2, Loss: 0.5232846140861511
Epoch: 2, Loss: 0.6812126636505127
Epoch: 3, Loss: 0.5486624240875244
Epoch: 3, Loss: 0.6258867383003235
Epoch: 3, Loss: 0.6736737489700317
Epoch: 3, Loss: 0.713619589805603
Epoch: 3, Loss: 0.6508697271347046
Epoch: 4, Loss: 0.7275972366333008
Epoch: 4, Loss: 0.5853260159492493
Epoch: 4, Loss: 0.6316661834716797
Epoch: 4, Loss: 0.6792041659355164
Epoch: 4, Loss: 0.7094566226005554
Epoch: 5, Loss: 0.6542940139770508
Epoch: 5, Loss: 0.5303098559379578
Epoch: 5, Loss: 0.7321962118148804
Epoch: 5, Loss: 0.62

In [54]:
test_results = []
with torch.no_grad():
    for sentence in test_sentences:
        embedded_sentence, _ = process_sentences(sentence, d_word, max_len)
        output = model(embedded_sentence)
        prediction = 1 if output.item() > 0.5 else 0
        test_results.append((sentence, prediction))

In [55]:
for sentence, pred in test_results:
    print(f"Sentence: '{sentence}' -> Predicted Sentiment: {pred}")

Sentence: 'This is fantastic' -> Predicted Sentiment: 0
Sentence: 'I regret buying this' -> Predicted Sentiment: 0
Sentence: 'I absolutely love it' -> Predicted Sentiment: 0
Sentence: 'Worst experience ever' -> Predicted Sentiment: 1
Sentence: 'Best purchase of my life' -> Predicted Sentiment: 0
