#### Copyright © 2023 Taeyoung Kim and Mingi Kang. All rights reserved. ####

This is the solution for the homework assignment of the Machine Learning and Optimization lecture for WS2023. 

### Part 1 ### 

In [5]:
# Model hyper-parameters (values taken from the lecture notes)

# Dimension of the embeddings
dim_embd = 768
# Number of heads in the multi-head attention mechanism
n_heads = 12
# Number of layers in the model
n_layers = 12
# Dimension of the feedforward network
dim_feedforward = 3072

In [6]:
# Non-embedding parameter count: only the weight matrices of the attention
# and feed-forward sub-layers are counted here.

# Multi-head attention, per layer:
#   - each head owns 3 matrices (Query, Key, Value), each of size
#     dim_embd x (dim_embd / n_heads)
#   - one output projection matrix of size dim_embd x dim_embd
output_proj_params = dim_embd * dim_embd
mha_params_per_head = 3 * ((dim_embd * dim_embd) // n_heads)
mha_params_per_layer = n_heads * mha_params_per_head

# Feed-forward network, per layer:
#   two linear layers, (dim_embd x dim_feedforward) and
#   (dim_feedforward x dim_embd), which hold the same number of weights
ffn_params_per_layer = 2 * (dim_embd * dim_feedforward)

# Sum of all non-embedding parameters in a single layer
total_params_per_layer = sum(
    (mha_params_per_layer, output_proj_params, ffn_params_per_layer)
)

# Non-embedding parameters of the whole transformer (all layers)
total_non_embedding_params = n_layers * total_params_per_layer
total_non_embedding_params

84934656

### Part 2 ###

In [3]:
# # FLOPs Calculation

# # Multi-Head Attention FLOPs per layer
# # Each head has Query, Key, Value
# # Dimension: dim_embd x dim_embd/n_heads
# # Output projection matrix multiplication: dim_embd x dim_embd
# mha_flops_per_head = 3 * 2 * (dim_embd * dim_embd * (dim_embd // n_heads))  # 3 matrix, 2*ijk FLOPs each
# output_proj_flops = 2 * (dim_embd * dim_embd * dim_embd)
# mha_flops_per_layer = mha_flops_per_head * n_heads + output_proj_flops

# # score computation: (dim_embd/n_heads x dim_embd/n_heads) for each head, per token
# # Softmax computation is neglected because its cost is relatively small
# attention_score_flops_per_head = 2 * ((dim_embd // n_heads) * (dim_embd // n_heads) * (dim_embd // n_heads))
# attention_flops_per_layer = attention_score_flops_per_head * n_heads

# # FFN FLOPs per layer
# # 1. linear layer:dim_embd x dim_feedforward
# # 2. linear layer: dim_feedforward x dim_embd
# ffn_flops_per_layer = 2 * (dim_embd * dim_feedforward * dim_embd)  # 2 layers, 2*ijk FLOPs each

# # Total FLOPs
# total_flops_per_layer = mha_flops_per_layer + attention_flops_per_layer + ffn_flops_per_layer

# # Total FLOPs in the transformer per token
# # Assumed that one token is processed at a time in the forward pass
# total_flops_per_token = total_flops_per_layer * n_layers
# total_flops_per_token


87048585216

In [7]:
def calculate_flops(i, j, k):
    """Return the FLOP count for multiplying matrix A (i x j) by matrix B (j x k).

    The product needs i * j * k multiply-accumulate steps, and each step is
    counted as two operations, giving 2 * i * j * k.
    """
    multiply_accumulates = i * j * k
    return 2 * multiply_accumulates

# Forward-pass FLOPs, assembled from the individual matrix products.
# Every projection below uses dim_embd as the row count i of the left-hand
# matrix A.
# NOTE(review): for a strict single-token count the row count would presumably
# be 1 instead of dim_embd -- confirm against the assignment.

# Multi-Head Attention FLOPs per layer
# Query, Key, Value (per head): A = (dim_embd, dim_embd), B = (dim_embd, dim_embd/n_heads)
# Output projection (once per layer): A = (dim_embd, dim_embd), B = (dim_embd, dim_embd)
qkv_flops = 3 * calculate_flops(dim_embd, dim_embd, dim_embd // n_heads)  # 3 matrices: Q, K, V
output_proj_flops = calculate_flops(dim_embd, dim_embd, dim_embd)
# Only the per-head Q/K/V projections scale with n_heads. The output projection
# is a single (dim_embd x dim_embd) product per layer, so it is added once
# rather than being multiplied by n_heads.
mha_flops_per_layer = qkv_flops * n_heads + output_proj_flops

# Attention score FLOPs (per head, then summed over all heads)
# A = (dim_embd/n_heads, dim_embd/n_heads), B = (dim_embd/n_heads, dim_embd/n_heads)
attention_score_flops_per_head = calculate_flops(dim_embd // n_heads, dim_embd // n_heads, dim_embd // n_heads)
attention_flops_per_layer = attention_score_flops_per_head * n_heads

# Feed-forward network FLOPs
# 1st linear layer: A = (dim_embd, dim_embd), B = (dim_embd, dim_feedforward)
# 2nd linear layer: A = (dim_embd, dim_feedforward), B = (dim_feedforward, dim_embd)
ffn_flops_1 = calculate_flops(dim_embd, dim_embd, dim_feedforward)
ffn_flops_2 = calculate_flops(dim_embd, dim_feedforward, dim_embd)
ffn_flops_per_layer = ffn_flops_1 + ffn_flops_2

# Total FLOPs of one layer
total_flops_per_layer = mha_flops_per_layer + attention_flops_per_layer + ffn_flops_per_layer

# Total FLOPs of the whole transformer (all layers)
total_flops_per_token = total_flops_per_layer * n_layers
total_flops_per_token


130535129088

### Part 3 ###

In [8]:
# Non-embedding parameter count N, copied from the result of Question 1
total_non_embedding_params = 84934656

# Estimated computational cost of training, per token:
#   forward pass:  equivalent to the number of non-embedding parameters N
#   backward pass: equivalent to double the forward pass
#   total passes:  3N
#   counting each operation in the matrix products as a separate
#   operation gives roughly 6N
estimated_training_cost_per_token = 3 * (2 * total_non_embedding_params)
estimated_training_cost_per_token

509607936