In [1]:
from datetime import datetime

#import matplotlib.pyplot as plt
import torch
import numpy as np

def starttimer():
    start = datetime.now()
    print(f'{start.strftime("%Y-%m-%d %H:%M:%S")}')
    return start

def endtimer(start):
    end = datetime.now()
    print(f'{end.strftime("%Y-%m-%d %H:%M:%S")}')
    durn = end - start
    print(f'Duration: {durn.total_seconds()}s')
    return end, durn

## Calculating Self-Attention

In [33]:
input_tensor = torch.Tensor([[0.1, 0.2, 0.9],
                             [0.5, 0.3, 0.8],
                             [0.6, 0.4, 0.8]])

print(f'input_tensor \n {input_tensor.numpy()}', end='\n-------------------------------------------------\n')

# in Self-Attention, Q=K=V
# create vectors for Q,K,V and K transposed
Q = K = V = torch.flatten(input=input_tensor)
Kt = K.reshape(-1,1)
print(f'Q \n {Q.numpy()}', end='\n-------------------------------------------------\n')
print(f'K \n {K.numpy()}', end='\n-------------------------------------------------\n')
print(f'V \n {V.numpy()}', end='\n-------------------------------------------------\n')
print(f'Kt \n {Kt.numpy()}', end='\n-------------------------------------------------\n')


input_tensor 
 [[0.1 0.2 0.9]
 [0.5 0.3 0.8]
 [0.6 0.4 0.8]]
-------------------------------------------------
Q 
 [0.1 0.2 0.9 0.5 0.3 0.8 0.6 0.4 0.8]
-------------------------------------------------
K 
 [0.1 0.2 0.9 0.5 0.3 0.8 0.6 0.4 0.8]
-------------------------------------------------
V 
 [0.1 0.2 0.9 0.5 0.3 0.8 0.6 0.4 0.8]
-------------------------------------------------
Kt 
 [[0.1]
 [0.2]
 [0.9]
 [0.5]
 [0.3]
 [0.8]
 [0.6]
 [0.4]
 [0.8]]
-------------------------------------------------


In [58]:
# Calculate Q.K(transposed) and A
QdotKt = torch.zeros(1,1)
QdotKt=Q*Kt
print(f'Q*Kt \n {QdotKt}', end='\n-------------------------------------------------\n')

Q*Kt 
 tensor([[0.0100, 0.0200, 0.0900, 0.0500, 0.0300, 0.0800, 0.0600, 0.0400, 0.0800],
        [0.0200, 0.0400, 0.1800, 0.1000, 0.0600, 0.1600, 0.1200, 0.0800, 0.1600],
        [0.0900, 0.1800, 0.8100, 0.4500, 0.2700, 0.7200, 0.5400, 0.3600, 0.7200],
        [0.0500, 0.1000, 0.4500, 0.2500, 0.1500, 0.4000, 0.3000, 0.2000, 0.4000],
        [0.0300, 0.0600, 0.2700, 0.1500, 0.0900, 0.2400, 0.1800, 0.1200, 0.2400],
        [0.0800, 0.1600, 0.7200, 0.4000, 0.2400, 0.6400, 0.4800, 0.3200, 0.6400],
        [0.0600, 0.1200, 0.5400, 0.3000, 0.1800, 0.4800, 0.3600, 0.2400, 0.4800],
        [0.0400, 0.0800, 0.3600, 0.2000, 0.1200, 0.3200, 0.2400, 0.1600, 0.3200],
        [0.0800, 0.1600, 0.7200, 0.4000, 0.2400, 0.6400, 0.4800, 0.3200, 0.6400]])
-------------------------------------------------


In [77]:
# Calculate Softmax of Q.Kt
softmax = torch.nn.Softmax(dim=0)
A = softmax(QdotKt)
print(f'A \n {A}', end='\n-------------------------------------------------\n')
print(f'sum(A) \n {A.sum()}', end='\n-------------------------------------------------\n')

A 
 tensor([[0.1066, 0.1022, 0.0746, 0.0897, 0.0979, 0.0782, 0.0857, 0.0937, 0.0782],
        [0.1077, 0.1043, 0.0816, 0.0943, 0.1009, 0.0847, 0.0910, 0.0975, 0.0847],
        [0.1155, 0.1199, 0.1532, 0.1338, 0.1245, 0.1482, 0.1385, 0.1291, 0.1482],
        [0.1109, 0.1107, 0.1069, 0.1095, 0.1104, 0.1076, 0.1090, 0.1100, 0.1076],
        [0.1088, 0.1064, 0.0893, 0.0991, 0.1040, 0.0917, 0.0966, 0.1015, 0.0917],
        [0.1143, 0.1176, 0.1400, 0.1272, 0.1208, 0.1368, 0.1304, 0.1240, 0.1368],
        [0.1121, 0.1129, 0.1169, 0.1151, 0.1137, 0.1166, 0.1157, 0.1145, 0.1166],
        [0.1098, 0.1085, 0.0977, 0.1042, 0.1071, 0.0994, 0.1026, 0.1057, 0.0994],
        [0.1143, 0.1176, 0.1400, 0.1272, 0.1208, 0.1368, 0.1304, 0.1240, 0.1368]])
-------------------------------------------------
sum(A) 
 9.0
-------------------------------------------------


In [78]:
# Calculate AV
AV = torch.zeros(1,1)
AV=torch.matmul(A,V)
print(f'AV \n {AV}', end='\n-------------------------------------------------\n')

output_tensor = torch.reshape(AV,(3,3))
print(f'input_tensor \n {input_tensor}', end='\n-------------------------------------------------\n')
print(f'output_tensor \n {output_tensor}', end='\n-------------------------------------------------\n')

AV 
 tensor([0.3864, 0.4115, 0.6495, 0.4989, 0.4386, 0.6076, 0.5325, 0.4676, 0.6076])
-------------------------------------------------
input_tensor 
 tensor([[0.1000, 0.2000, 0.9000],
        [0.5000, 0.3000, 0.8000],
        [0.6000, 0.4000, 0.8000]])
-------------------------------------------------
output_tensor 
 tensor([[0.3864, 0.4115, 0.6495],
        [0.4989, 0.4386, 0.6076],
        [0.5325, 0.4676, 0.6076]])
-------------------------------------------------


## Calculating Cross-Attention

In [33]:
feature_input = torch.Tensor([[0.1, 0.2, 0.9],
                             [0.5, 0.3, 0.8],
                             [0.6, 0.4, 0.8]])

skip_connection = torch.Tensor([[0.8, 0.6, 0.4, 0.4, 0.8, 0.9],
                                [0.8, 0.8, 0.7, 0.4, 0.7, 0.9],
                                [0.7, 0.6, 0.7, 0.4, 0.4, 0.4],
                                [0.5, 0.4, 0.2, 0.1, 0.1, 0.1],
                                [0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
                                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])

print(f'skip_connection \n {skip_connection.numpy()}', end='\n-------------------------------------------------\n')
print(f'feature_input \n {feature_input.numpy()}', end='\n-------------------------------------------------\n')

input_tensor 
 [[0.1 0.2 0.9]
 [0.5 0.3 0.8]
 [0.6 0.4 0.8]]
-------------------------------------------------
Q 
 [0.1 0.2 0.9 0.5 0.3 0.8 0.6 0.4 0.8]
-------------------------------------------------
K 
 [0.1 0.2 0.9 0.5 0.3 0.8 0.6 0.4 0.8]
-------------------------------------------------
V 
 [0.1 0.2 0.9 0.5 0.3 0.8 0.6 0.4 0.8]
-------------------------------------------------
Kt 
 [[0.1]
 [0.2]
 [0.9]
 [0.5]
 [0.3]
 [0.8]
 [0.6]
 [0.4]
 [0.8]]
-------------------------------------------------


In [None]:
# in Cross-Attention, Q=K feature input and V = skip connection
# create vectors for Q,K,V and K transposed
Q = K = V = torch.flatten(input=input_tensor)
Kt = K.reshape(-1,1)
print(f'Q \n {Q.numpy()}', end='\n-------------------------------------------------\n')
print(f'K \n {K.numpy()}', end='\n-------------------------------------------------\n')
print(f'V \n {V.numpy()}', end='\n-------------------------------------------------\n')
print(f'Kt \n {Kt.numpy()}', end='\n-------------------------------------------------\n')

In [58]:
# Calculate Q.K(transposed) and A
QdotKt = torch.zeros(1,1)
QdotKt=Q*Kt
print(f'Q*Kt \n {QdotKt}', end='\n-------------------------------------------------\n')

Q*Kt 
 tensor([[0.0100, 0.0200, 0.0900, 0.0500, 0.0300, 0.0800, 0.0600, 0.0400, 0.0800],
        [0.0200, 0.0400, 0.1800, 0.1000, 0.0600, 0.1600, 0.1200, 0.0800, 0.1600],
        [0.0900, 0.1800, 0.8100, 0.4500, 0.2700, 0.7200, 0.5400, 0.3600, 0.7200],
        [0.0500, 0.1000, 0.4500, 0.2500, 0.1500, 0.4000, 0.3000, 0.2000, 0.4000],
        [0.0300, 0.0600, 0.2700, 0.1500, 0.0900, 0.2400, 0.1800, 0.1200, 0.2400],
        [0.0800, 0.1600, 0.7200, 0.4000, 0.2400, 0.6400, 0.4800, 0.3200, 0.6400],
        [0.0600, 0.1200, 0.5400, 0.3000, 0.1800, 0.4800, 0.3600, 0.2400, 0.4800],
        [0.0400, 0.0800, 0.3600, 0.2000, 0.1200, 0.3200, 0.2400, 0.1600, 0.3200],
        [0.0800, 0.1600, 0.7200, 0.4000, 0.2400, 0.6400, 0.4800, 0.3200, 0.6400]])
-------------------------------------------------


In [77]:
# Calculate Softmax of Q.Kt
softmax = torch.nn.Softmax(dim=0)
A = softmax(QdotKt)
print(f'A \n {A}', end='\n-------------------------------------------------\n')
print(f'sum(A) \n {A.sum()}', end='\n-------------------------------------------------\n')

A 
 tensor([[0.1066, 0.1022, 0.0746, 0.0897, 0.0979, 0.0782, 0.0857, 0.0937, 0.0782],
        [0.1077, 0.1043, 0.0816, 0.0943, 0.1009, 0.0847, 0.0910, 0.0975, 0.0847],
        [0.1155, 0.1199, 0.1532, 0.1338, 0.1245, 0.1482, 0.1385, 0.1291, 0.1482],
        [0.1109, 0.1107, 0.1069, 0.1095, 0.1104, 0.1076, 0.1090, 0.1100, 0.1076],
        [0.1088, 0.1064, 0.0893, 0.0991, 0.1040, 0.0917, 0.0966, 0.1015, 0.0917],
        [0.1143, 0.1176, 0.1400, 0.1272, 0.1208, 0.1368, 0.1304, 0.1240, 0.1368],
        [0.1121, 0.1129, 0.1169, 0.1151, 0.1137, 0.1166, 0.1157, 0.1145, 0.1166],
        [0.1098, 0.1085, 0.0977, 0.1042, 0.1071, 0.0994, 0.1026, 0.1057, 0.0994],
        [0.1143, 0.1176, 0.1400, 0.1272, 0.1208, 0.1368, 0.1304, 0.1240, 0.1368]])
-------------------------------------------------
sum(A) 
 9.0
-------------------------------------------------


In [78]:
# Calculate AV
AV = torch.zeros(1,1)
AV=torch.matmul(A,V)
print(f'AV \n {AV}', end='\n-------------------------------------------------\n')

output_tensor = torch.reshape(AV,(3,3))
print(f'input_tensor \n {input_tensor}', end='\n-------------------------------------------------\n')
print(f'output_tensor \n {output_tensor}', end='\n-------------------------------------------------\n')

AV 
 tensor([0.3864, 0.4115, 0.6495, 0.4989, 0.4386, 0.6076, 0.5325, 0.4676, 0.6076])
-------------------------------------------------
input_tensor 
 tensor([[0.1000, 0.2000, 0.9000],
        [0.5000, 0.3000, 0.8000],
        [0.6000, 0.4000, 0.8000]])
-------------------------------------------------
output_tensor 
 tensor([[0.3864, 0.4115, 0.6495],
        [0.4989, 0.4386, 0.6076],
        [0.5325, 0.4676, 0.6076]])
-------------------------------------------------
