In [136]:
from torch2onnx import encode, compute_flops, GPT5
from transformers import AutoModel, AutoTokenizer
from thop import profile


# Checkpoint identifier handed to both the tokenizer and the model wrapper.
# NOTE(review): 'xd' looks like a placeholder or a local alias; the model repr
# printed by this cell is a BertModel with a 28996-token vocabulary - confirm
# which checkpoint this actually resolves to.
model_name = 'xd'
# Destination path for an ONNX export; the cells visible in this section never read it.
onnx_filename = 'models/model.onnx'

# Load the tokenizer that belongs to the checkpoint (no text is tokenized here yet).
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Build the model under analysis.
# N is presumably the output width of the `fc` head (model.fc is printed later
# as Linear(in_features=768, out_features=16)) - confirm against GPT5.
N = 16
model = GPT5(N, model_name)
# Switch to evaluation mode; being the cell's last expression, the returned
# module is also displayed as the cell output.
model.eval()

GPT5(
  (transformer): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

### FLOPs for a batch of 1 element

In [77]:
# Per-layer FLOPs for a single sentence (batch of 1).
input_ids, attention_mask = encode(tokenizer, "My name is Clara and I live in Berkeley, California.")

trans = model.transformer

# BertEmbeddings.forward takes (input_ids, token_type_ids, position_ids, ...);
# it has no attention-mask argument. Passing the mask positionally would feed
# it in as token_type_ids, so only the token ids are supplied here.
output = trans.embeddings(input_ids)
output_enc = trans.encoder(output)

# (module to profile, inputs for its forward pass, sub-layer names to report).
# A filter of None reports every direct sub-layer of the module.
# The order of the entries fixes the order of the printed lines.
profile_targets = [
    (model, (input_ids, attention_mask), {'fc'}),
    (trans.encoder.layer[0], (output,), {'attention'}),
    (trans.pooler, (output_enc[0],), None),
    (trans.embeddings, (input_ids,), {'position_embeddings'}),
]

for module, module_inputs, wanted_layers in profile_targets:
    # profile returns (total MACs, total params, per-layer info); only the
    # per-layer breakdown is needed for the report.
    layer_info = profile(module, inputs=module_inputs, ret_layer_info=True, verbose=False)[2]
    for layer, (layer_macs, _, _) in layer_info.items():
        if wanted_layers is None or layer in wanted_layers:
            # thop counts multiply-accumulates; one MAC is counted as 2 FLOPs.
            layer_flops = 2 * layer_macs
            print(f"Слой: {layer}, FLOPs: {layer_flops}")

Слой: fc, FLOPs: 344064.0
Слой: attention, FLOPs: 66146304.0
Слой: dense, FLOPs: 1179648.0
Слой: activation, FLOPs: 0.0
Слой: position_embeddings, FLOPs: 0.0


### FLOPs for a batch of 20 elements

In [78]:
# Per-layer FLOPs for a batch of identical sentences.
message = "My name is Clara and I live in Berkeley, California."
batch_size = 20
messages = [message] * batch_size  # batch_size copies of the same message

inputs = tokenizer(messages, padding=True, truncation=True, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

trans = model.transformer

# BertEmbeddings.forward takes (input_ids, token_type_ids, position_ids, ...);
# it has no attention-mask argument. Passing the mask positionally would feed
# it in as token_type_ids, so only the token ids are supplied here.
output = trans.embeddings(input_ids)
output_enc = trans.encoder(output)

# (module to profile, inputs for its forward pass, sub-layer names to report).
# A filter of None reports every direct sub-layer of the module.
# The order of the entries fixes the order of the printed lines.
profile_targets = [
    (model, (input_ids, attention_mask), {'fc'}),
    (trans.encoder.layer[0], (output,), {'attention'}),
    (trans.pooler, (output_enc[0],), None),
    (trans.embeddings, (input_ids,), {'position_embeddings'}),
]

for module, module_inputs, wanted_layers in profile_targets:
    # profile returns (total MACs, total params, per-layer info); only the
    # per-layer breakdown is needed for the report.
    layer_info = profile(module, inputs=module_inputs, ret_layer_info=True, verbose=False)[2]
    for layer, (layer_macs, _, _) in layer_info.items():
        if wanted_layers is None or layer in wanted_layers:
            # thop counts multiply-accumulates; one MAC is counted as 2 FLOPs.
            layer_flops = 2 * layer_macs
            print(f"Слой: {layer}, FLOPs: {layer_flops}")

Слой: fc, FLOPs: 6881280.0
Слой: attention, FLOPs: 1322926080.0
Слой: dense, FLOPs: 23592960.0
Слой: activation, FLOPs: 0.0
Слой: position_embeddings, FLOPs: 0.0


### Memory

In [86]:
import numpy as np

def fc_memory(input_size, output_size, batch_size, dtype=np.float32):
    """Estimate the bytes attributed to a fully connected layer.

    The estimate counts the weight matrix (``input_size * output_size``) and
    the bias vector (``output_size``), converts that count to bytes with the
    item size of ``dtype``, and multiplies the result by ``batch_size``.

    NOTE(review): parameter storage does not normally grow with the batch;
    the batch scaling is kept exactly as written - confirm that this is the
    intended memory model.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
        batch_size: Multiplier applied to the parameter bytes.
        dtype: Element type used to look up the per-value byte size.

    Returns:
        The estimated size in bytes.
    """
    bytes_per_value = np.dtype(dtype).itemsize
    parameter_count = input_size * output_size + output_size
    return parameter_count * bytes_per_value * batch_size

def attention_memory(batch_size, hidden_size, dtype=np.float32):
    """Estimate the bytes attributed to an attention block's weight matrices.

    Four ``hidden_size x hidden_size`` matrices are counted (three plus one -
    presumably the query/key/value projections and the output projection;
    confirm). Bias vectors are not included. The byte total is multiplied by
    ``batch_size``.

    NOTE(review): weight storage does not normally grow with the batch; the
    batch scaling is kept exactly as written - confirm the intended model.

    Args:
        batch_size: Multiplier applied to the weight bytes.
        hidden_size: Side length of each square weight matrix.
        dtype: Element type used to look up the per-value byte size.

    Returns:
        The estimated size in bytes.
    """
    matrix_values = hidden_size**2
    weight_values = 3 * matrix_values + matrix_values
    return weight_values * np.dtype(dtype).itemsize * batch_size

def positional_embedding_memory(batch_size, max_seq_len, hidden_size, dtype=np.float32):
    """Estimate the bytes attributed to a positional-embedding table.

    The table holds ``max_seq_len * hidden_size`` values; its size in bytes
    (using the item size of ``dtype``) is multiplied by ``batch_size``.

    NOTE(review): the table itself does not normally grow with the batch;
    the batch scaling is kept exactly as written - confirm the intended model.

    Args:
        batch_size: Multiplier applied to the table bytes.
        max_seq_len: Number of rows (positions) in the table.
        hidden_size: Number of values per row.
        dtype: Element type used to look up the per-value byte size.

    Returns:
        The estimated size in bytes.
    """
    bytes_per_value = np.dtype(dtype).itemsize
    return max_seq_len * hidden_size * bytes_per_value * batch_size

def tanh_memory(batch_size, seq_len, hidden_size, dtype=np.float32):
    """Estimate the bytes attributed to a tanh activation.

    One tensor of ``batch_size * seq_len * hidden_size`` values is sized in
    bytes with the item size of ``dtype`` and then counted twice (presumably
    once for the input and once for the output - confirm).

    Args:
        batch_size: Number of samples.
        seq_len: Number of positions per sample.
        hidden_size: Number of values per position.
        dtype: Element type used to look up the per-value byte size.

    Returns:
        The estimated size in bytes.
    """
    single_tensor_bytes = batch_size * seq_len * hidden_size * np.dtype(dtype).itemsize
    return 2 * single_tensor_bytes

In [90]:
# Report the memory estimate of the projection head for both batch sizes.
fc_layer = model.fc
print(fc_layer)
for batch in (1, 20):
    print(f'Memory fc batch {batch}: ', fc_memory(fc_layer.in_features, fc_layer.out_features, batch), 'bytes')

Linear(in_features=768, out_features=16, bias=True)
Memory fc batch 1:  49216 bytes
Memory fc batch 20:  984320 bytes


In [93]:
# Report the memory estimate of the pooler's dense layer for both batch sizes.
pooler_dense = model.transformer.pooler.dense
print(pooler_dense)
for batch in (1, 20):
    print(f'Memory dense batch {batch}: ', fc_memory(pooler_dense.in_features, pooler_dense.out_features, batch), 'bytes')

Linear(in_features=768, out_features=768, bias=True)
Memory dense batch 1:  2362368 bytes
Memory dense batch 20:  47247360 bytes


In [105]:
# Report the memory estimate of the pooler's activation for both batch sizes.
# seq_len is fixed to 1 (presumably because the pooler works on a single
# token's hidden state - confirm).
pooler = model.transformer.pooler
print(pooler.activation)
for batch in (1, 20):
    print(f'Memory tanh batch {batch}: ', tanh_memory(batch, 1, pooler.dense.out_features), 'bytes')

Tanh()
Memory tanh batch 1:  6144 bytes
Memory tanh batch 20:  122880 bytes


In [114]:
# Report the memory estimate of the positional-embedding table for both batch sizes.
position_table = trans.embeddings.position_embeddings
print(position_table)
for batch in (1, 20):
    print(f'Memory PE batch {batch}: ', positional_embedding_memory(batch, position_table.num_embeddings, position_table.embedding_dim), 'bytes')

Embedding(512, 768)
Memory PE batch 1:  1572864 bytes
Memory PE batch 20:  31457280 bytes


In [ ]:
# NOTE(review): this cell has no execution count and repeats the
# positional-embedding memory report of the cell directly before it;
# it looks like a leftover copy and is a candidate for removal.
print(trans.embeddings.position_embeddings)
print('Memory PE batch 1: ', positional_embedding_memory(1, trans.embeddings.position_embeddings.num_embeddings, trans.embeddings.position_embeddings.embedding_dim), 'bytes')
print('Memory PE batch 20: ', positional_embedding_memory(20, trans.embeddings.position_embeddings.num_embeddings, trans.embeddings.position_embeddings.embedding_dim), 'bytes')

In [118]:
# Report the memory estimate of the first encoder layer's attention block.
attention_block = trans.encoder.layer[0].attention
print(attention_block)
# Read the hidden size from the query projection instead of hard-coding 768,
# so the report stays correct if a checkpoint with a different width is loaded.
attention_hidden_size = attention_block.self.query.in_features
print('Memory Attn batch 1: ', attention_memory(1, attention_hidden_size), 'bytes')
print('Memory Attn batch 20: ', attention_memory(20, attention_hidden_size), 'bytes')

BertAttention(
  (self): BertSdpaSelfAttention(
    (query): Linear(in_features=768, out_features=768, bias=True)
    (key): Linear(in_features=768, out_features=768, bias=True)
    (value): Linear(in_features=768, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (output): BertSelfOutput(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
Memory Attn batch 1:  9437184 bytes
Memory Attn batch 20:  188743680 bytes


# Results Table

| Layer | FLOPs (batch 1, 20)    | Memory in bytes (batch 1, 20) | FLOPs/Memory (batch 1, 20) | Limited |
|-------|------------------------|-------------------------------|----------------------------|---------|
| Proj  | (344064, 6881280)      | (49216, 984320)               | (6.99, 6.99)               | -       |
| Dense | (1179648, 23592960)    | (2362368, 47247360)           | (0.50, 0.50)               | -       |
| Tanh  | (0, 0)                 | (6144, 122880)                | (0, 0)                     | -       |
| Attn  | (66146304, 1322926080) | (9437184, 188743680)          | (7.01, 7.01)               | -       |
| PE    | (0, 0)                 | (1572864, 31457280)           | (0, 0)                     | -       |
