<a href="https://colab.research.google.com/github/RxAI-dev/rxnn-notebooks/blob/main/Experimental_Attention_Check.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install RxNN and dependencies

In [None]:
# Pin rxnn and torch; the flash-attn wheel downloaded below is built for torch 2.6 (see its filename).
!pip install rxnn==0.1.59 torch==2.6.0 transformers tokenizers huggingface_hub datasets

In [None]:
# Prebuilt flash-attn 2.7.4.post1 wheel; per its filename: CUDA 12, torch 2.6, CPython 3.11, linux x86_64.
!wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl

In [None]:
# --no-dependencies: install only the wheel itself, presumably so pip does not touch the pinned torch install.
!pip install --no-dependencies --upgrade flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl

## Import libraries

In [None]:
import torch, numpy as np, random, gc, os
from rxnn.experimental.models import ExperimentalAttentionTransformer
from rxnn.training.dataset import AutoregressiveLMDataset
from rxnn.training.bml import AutoregressiveTrainer
from rxnn.training.scheduler import get_transformer_lr_scheduler
from rxnn.training.callbacks import PrintLossCallback, PrintAccuracyCallback, TokenCounterCallback, ModelSaveCallback
from rxnn.training.tokenizer import TokenizerTrainer, load_tokenizer_from_hf_hub
from rxnn.utils import get_model_size
from datetime import datetime

In [None]:
# Sanity check: display the installed torch version (pinned to 2.6.0 in the install cell).
import torch
torch.__version__

In [None]:
# Sanity check: flash-attn imports cleanly against this torch build; display its version.
import flash_attn
flash_attn.__version__

## Mixture-of-Experts Attention - Test on micro-size models (~2.5M params) with the TinyStories dataset

> Unfortunately, Mixture-of-Experts Attention performed worse than classic GQA/MQA, so we abandoned this study and moved on to **Sparse Query Attention** research.



In [None]:
# Micro-size experiment (TinyStories): dimensions shared by every attention variant.
embed_dim = 128
vocab_size = 5_000
seq_len = 256

# Seed every RNG before constructing the models. Initial weights depend on this seed
# AND on the construction order below, so keep both unchanged for reproducibility.
torch.random.manual_seed(2137)
np.random.seed(2137)
random.seed(2137)

# Settings shared by all variants; each variant config below only sets 'att_type'
# (plus query groups for DMA).
base_config = {
    'num_layers': 6,
    'vocab_size': vocab_size,
    'embed_dim': embed_dim,
    'att_heads': 8,
    # NOTE(review): inherited by every variant, including MHA/MQA; presumably only the
    # grouped variants use it — confirm in rxnn.
    'att_groups': 2,
    'seq_len': seq_len,
    'use_flash_attention': False,
    'use_gated': True,
    'ff_dropout': 0.1,
    'ff_activation': 'silu',
    'ff_dim': 384,
}

gqa_config = {
    **base_config,
    'att_type': 'gqa',
}

mqa_config = {
    **base_config,
    'att_type': 'mqa',
}

mha_config = {
    **base_config,
    'att_type': 'mha',
}


gma_config = {
    **base_config,
    'att_type': 'gma',
}

dma_config = {
    **base_config,
    'att_type': 'dma',
    'att_num_query_groups': 4
}


# Construction order matters: each model consumes RNG state seeded above.
gqa_decoder = ExperimentalAttentionTransformer(**gqa_config)
mqa_decoder = ExperimentalAttentionTransformer(**mqa_config)
mha_decoder = ExperimentalAttentionTransformer(**mha_config)
gma_decoder = ExperimentalAttentionTransformer(**gma_config)
dma_decoder = ExperimentalAttentionTransformer(**dma_config)

# Display the parameter count of each variant (last expression of the cell).
(('GQA', gqa_decoder.params_count()),
('MQA', mqa_decoder.params_count()),
('MHA', mha_decoder.params_count()),
('GMA', gma_decoder.params_count()),
('DMA', dma_decoder.params_count()))

### Computational efficiency

In [None]:
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
  torch.cuda.empty_cache()
print(device)

# Dummy input: a batch of 256 sequences of 256 tokens (every token id is 1).
inp = torch.ones(256, 256, dtype=torch.int32).to(device)
gqa_decoder = gqa_decoder.to(device)
mha_decoder = mha_decoder.to(device)
mqa_decoder = mqa_decoder.to(device)
dma_decoder = dma_decoder.to(device)
gma_decoder = gma_decoder.to(device)

steps = 100   # timed forward passes per model
warmup = 100  # untimed forward passes per model, run right before its timed loop


def _sync_device():
  """Block until all work queued on `device` has finished.

  CUDA kernels launch asynchronously. Without this sync the host clock only
  measures launch overhead, and unfinished work spills into the next
  model's measurement.
  """
  if device.type == 'cuda':
    torch.cuda.synchronize()


def _time_forward(name, model):
  """Warm up `model`, then print the total and per-step time of `steps` forward passes on `inp`."""
  for _ in range(warmup):
    model(inp)
  _sync_device()
  t1 = time.perf_counter()
  for _ in range(steps):
    model(inp)
  _sync_device()
  t2 = time.perf_counter()
  print(f'{name}: ', t2 - t1, (t2 - t1) / steps)


# no_grad: measure forward cost only, without recording autograd graphs
# (which would also add memory pressure between iterations).
with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
  print(inp.dtype)
  # Same order as before: DMA, GMA, GQA, MHA, MQA; every model now gets its own warmup.
  for _name, _model in (('DMA', dma_decoder), ('GMA', gma_decoder), ('GQA', gqa_decoder),
                        ('MHA', mha_decoder), ('MQA', mqa_decoder)):
    _time_forward(_name, _model)

### Get tokenizer

In [None]:
# Reuse the tokenizer published with RxT-Alpha-Micro-Decoder and obtain a Hugging Face
# tokenizer object from it (via get_hf_tokenizer); the result is displayed.
tr = TokenizerTrainer.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Decoder')
tokenizer = tr.get_hf_tokenizer()
tokenizer


### Load datasets

In [None]:
from datasets import load_dataset
# TinyStories: the 'train' split is used for training, the 'validation' split for evaluation.
stories_dataset = load_dataset('roneneldan/TinyStories', split='train', trust_remote_code=True)
hf_valid_dataset = load_dataset('roneneldan/TinyStories', split='validation', trust_remote_code=True)
# Display the size of both splits.
len(stories_dataset), len(hf_valid_dataset)

### Train MultiHeadAttention model

In [None]:
# selected model: MHA
decoder = mha_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held over from earlier cells before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(stories_dataset))

# Batches per epoch minus one. NOTE(review): presumably discounts a trailing partial
# batch — confirm against the trainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps; warmup covers the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att/tensorboard_logs/mha'
# exist_ok=True replaces the racy exists()-then-makedirs() check.
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8e9 looks like a token-count limit for the counter — confirm in rxnn.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves under ./att/mha and pushes checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att/mha', push_to_hub=True,
                            hub_model_id='ReactiveAI/MHA-MAT', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='MHA Ref Transformer v2')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# Dense attention: the MoE auxiliary loss is disabled.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock the whole trainer call, so the time also covers validation and callbacks.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Total time is shown in minutes, per-step time in seconds.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### Train GroupedQueryAttention model

In [None]:
# selected model: GQA
decoder = gqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held over from earlier cells before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(stories_dataset))

# Batches per epoch minus one. NOTE(review): presumably discounts a trailing partial
# batch — confirm against the trainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps; warmup covers the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att/tensorboard_logs/gqa'
# exist_ok=True replaces the racy exists()-then-makedirs() check.
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8e9 looks like a token-count limit for the counter — confirm in rxnn.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves under ./att/gqa and pushes checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att/gqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/GQA-MAT', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GQA Ref Transformer v2')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# Dense attention: the MoE auxiliary loss is disabled.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock the whole trainer call, so the time also covers validation and callbacks.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Total time is shown in minutes, per-step time in seconds.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### Train MultiQueryAttention model

In [None]:
# selected model: MQA
decoder = mqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held over from earlier cells before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(stories_dataset))

# Batches per epoch minus one. NOTE(review): presumably discounts a trailing partial
# batch — confirm against the trainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps; warmup covers the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att/tensorboard_logs/mqa'
# exist_ok=True replaces the racy exists()-then-makedirs() check.
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8e9 looks like a token-count limit for the counter — confirm in rxnn.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves under ./att/mqa and pushes checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att/mqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/MQA-MAT', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='MQA Ref Transformer v2')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# Dense attention: the MoE auxiliary loss is disabled.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock the whole trainer call, so the time also covers validation and callbacks.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Total time is shown in minutes, per-step time in seconds.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### Train GroupedMoeAttention model

In [None]:
# selected model: GMA
decoder = gma_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held over from earlier cells before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(stories_dataset))

# Batches per epoch minus one. NOTE(review): presumably discounts a trailing partial
# batch — confirm against the trainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps; warmup covers the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att/tensorboard_logs/gma'
# exist_ok=True replaces the racy exists()-then-makedirs() check.
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8e9 looks like a token-count limit for the counter — confirm in rxnn.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves under ./att/gma and pushes checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att/gma', push_to_hub=True,
                            hub_model_id='ReactiveAI/GMAT', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GroupedMoeAttentionTransformer v2')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# MoE attention: the MoE auxiliary loss is enabled, scaled by 0.01.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock the whole trainer call, so the time also covers validation and callbacks.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Total time is shown in minutes, per-step time in seconds.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### Train DeepMoeAttention model

In [None]:
# selected model: DMA
decoder = dma_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held over from earlier cells before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(stories_dataset))

# Batches per epoch minus one. NOTE(review): presumably discounts a trailing partial
# batch — confirm against the trainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps; warmup covers the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att/tensorboard_logs/dma'
# exist_ok=True replaces the racy exists()-then-makedirs() check.
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8e9 looks like a token-count limit for the counter — confirm in rxnn.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves under ./att/dma and pushes checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att/dma', push_to_hub=True,
                            hub_model_id='ReactiveAI/DMAT', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='DeepMoeAttentionTransformer v1')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# MoE attention: the MoE auxiliary loss is enabled, scaled by 0.01.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock the whole trainer call, so the time also covers validation and callbacks.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Total time is shown in minutes, per-step time in seconds.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### Train Extended DMA model

In [None]:
# selected model: Extended DMA
# Same micro dimensions as base_config, but with 8 KV groups, 16 experts, 16 query
# experts and 8 query groups.
ext_config = {
    'num_layers': 6,
    'vocab_size': vocab_size,
    'embed_dim': embed_dim,
    'att_heads': 8,
    'att_groups': 8,
    'seq_len': seq_len,
    'use_flash_attention': False,
    'use_gated': True,
    'ff_dropout': 0.1,
    'ff_activation': 'silu',
    'ff_dim': 384,
    'att_type': 'dma',
    'att_num_experts': 16,
    'att_num_query_experts': 16,
    'att_num_query_groups': 8
}
decoder = ExperimentalAttentionTransformer(**ext_config)
print(decoder.params_count())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held over from earlier cells before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(stories_dataset))

# Batches per epoch minus one. NOTE(review): presumably discounts a trailing partial
# batch — confirm against the trainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps; warmup covers the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att/tensorboard_logs/xdma'
# exist_ok=True replaces the racy exists()-then-makedirs() check.
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8e9 looks like a token-count limit for the counter — confirm in rxnn.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves under ./att/xdma and pushes checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att/xdma', push_to_hub=True,
                            hub_model_id='ReactiveAI/xDMAT', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='ExtendedDeepMoeAttentionTransformer v1')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# MoE attention: the MoE auxiliary loss is enabled, scaled by 0.01.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock the whole trainer call, so the time also covers validation and callbacks.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Total time is shown in minutes, per-step time in seconds.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

## Mixture-of-Experts Attention - Test on mini-size models (~11-12M params) with the Wikipedia dataset

> For faster results, only 50% of the Wikipedia dataset (>3M items) is used: 45% for training and 5% for validation.

> Unfortunately, Mixture-of-Experts Attention performed worse than classic GQA/MQA, so we abandoned this study and moved on to **Sparse Query Attention** research.

> **MHA/GQA/MQA** results from this run were later used as baselines for comparison with **SQA** variants.

In [None]:
# Mini-size experiment (Wikipedia): dimensions shared by every attention variant.
embed_dim = 256
vocab_size = 10_000
seq_len = 1024

# Seed every RNG before constructing the models. Initial weights depend on this seed
# AND on the construction order below, so keep both unchanged for reproducibility.
torch.random.manual_seed(210037)
np.random.seed(210037)
random.seed(210037)

# Settings shared by all variants; each variant config below only sets 'att_type'
# (plus query groups for DMA).
base_config = {
    'num_layers': 8,
    'vocab_size': vocab_size,
    'embed_dim': embed_dim,
    'att_heads': 16,
    # NOTE(review): inherited by every variant, including MHA/MQA; presumably only the
    # grouped variants use it — confirm in rxnn.
    'att_groups': 4,
    'seq_len': seq_len,
    'use_flash_attention': False,
    'use_gated': True,
    'ff_dropout': 0.1,
    'ff_activation': 'silu',
    'ff_dim': 768,
}

gqa_config = {
    **base_config,
    'att_type': 'gqa',
}

mqa_config = {
    **base_config,
    'att_type': 'mqa',
}

mha_config = {
    **base_config,
    'att_type': 'mha',
}


gma_config = {
    **base_config,
    'att_type': 'gma',
}

dma_config = {
    **base_config,
    'att_type': 'dma',
    'att_num_query_groups': 8
}


# Construction order matters: each model consumes RNG state seeded above.
# These rebind the *_decoder globals, replacing the micro models.
gqa_decoder = ExperimentalAttentionTransformer(**gqa_config)
mqa_decoder = ExperimentalAttentionTransformer(**mqa_config)
mha_decoder = ExperimentalAttentionTransformer(**mha_config)
gma_decoder = ExperimentalAttentionTransformer(**gma_config)
dma_decoder = ExperimentalAttentionTransformer(**dma_config)

# Display the parameter count of each variant (last expression of the cell).
(('GQA', gqa_decoder.params_count()),
('MQA', mqa_decoder.params_count()),
('MHA', mha_decoder.params_count()),
('GMA', gma_decoder.params_count()),
('DMA', dma_decoder.params_count()))

#### Use pretrained tokenizer from RxT-Alpha Mini

In [None]:
# Load the tokenizer published with RxT-Alpha-Mini-Decoder from the Hugging Face Hub; the result is displayed.
# NOTE(review): its vocabulary is assumed to match vocab_size = 10_000 in the mini config cell — confirm.
tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Alpha-Mini-Decoder')
tokenizer

### Load and split dataset

In [None]:
from datasets import load_dataset
# English Wikipedia dump '20231101.en'; only the 'train' split exists, so it is split manually below.
wiki_dataset = load_dataset('wikimedia/wikipedia', '20231101.en', split='train', trust_remote_code=True)
# Contiguous (unshuffled) index ranges: first 45% -> train, next 5% -> validation,
# remaining ~50% -> held back.
train_split_len = int(len(wiki_dataset) * 0.45)
valid_split_len = int(len(wiki_dataset) * 0.05)
wiki_train_dataset = wiki_dataset.select(range(train_split_len))
wiki_valid_dataset = wiki_dataset.select(range(train_split_len, train_split_len + valid_split_len))
# Held-back rows; used later as training data for the "next round" epoch.
_rest_dataset = wiki_dataset.select(range(train_split_len + valid_split_len, len(wiki_dataset)))
# Display the size of the three splits.
len(wiki_train_dataset), len(wiki_valid_dataset), len(_rest_dataset)

### Train GQA Mini model




In [None]:
# selected model: GQA
decoder = gqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held over from earlier cells before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# Batches per epoch minus one. NOTE(review): presumably discounts a trailing partial
# batch — confirm against the trainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps; warmup covers the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att_mini/tensorboard_logs/gqa'
# exist_ok=True replaces the racy exists()-then-makedirs() check.
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8e9 looks like a token-count limit for the counter — confirm in rxnn.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves under ./att_mini/gqa and pushes checkpoint weights to a private Hub repo.
# Commit message completed to 'v1' (was a truncated 'v'), matching the other mini runs;
# the follow-up run on the remaining data pushes 'v2' to the same repo.
save_cb = ModelSaveCallback('./att_mini/gqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/GQA-MAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GQA Mini Ref Transformer v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# Dense attention: the MoE auxiliary loss is disabled.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock the whole trainer call, so the time also covers validation and callbacks.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Total time is shown in minutes, per-step time in seconds.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### Train MHA Mini model

In [None]:
# selected model: MHA
decoder = mha_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held over from earlier cells before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# Batches per epoch minus one. NOTE(review): presumably discounts a trailing partial
# batch — confirm against the trainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps; warmup covers the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att_mini/tensorboard_logs/mha'
# exist_ok=True replaces the racy exists()-then-makedirs() check.
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8e9 looks like a token-count limit for the counter — confirm in rxnn.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves under ./att_mini/mha and pushes checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att_mini/mha', push_to_hub=True,
                            hub_model_id='ReactiveAI/MHA-MAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='MHA Mini Ref Transformer v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# Dense attention: the MoE auxiliary loss is disabled.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock the whole trainer call, so the time also covers validation and callbacks.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Total time is shown in minutes, per-step time in seconds.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### Train MQA Mini model

In [None]:
# selected model: MQA
decoder = mqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held over from earlier cells before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# Batches per epoch minus one. NOTE(review): presumably discounts a trailing partial
# batch — confirm against the trainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps; warmup covers the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att_mini/tensorboard_logs/mqa'
# exist_ok=True replaces the racy exists()-then-makedirs() check.
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8e9 looks like a token-count limit for the counter — confirm in rxnn.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves under ./att_mini/mqa and pushes checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att_mini/mqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/MQA-MAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='MQA Mini Ref Transformer v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# Dense attention: the MoE auxiliary loss is disabled.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock the whole trainer call, so the time also covers validation and callbacks.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Total time is shown in minutes, per-step time in seconds.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### Train GMA Mini model

In [None]:
# selected model: GMA
decoder = gma_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held over from earlier cells before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# Batches per epoch minus one. NOTE(review): presumably discounts a trailing partial
# batch — confirm against the trainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps; warmup covers the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att_mini/tensorboard_logs/gma'
# exist_ok=True replaces the racy exists()-then-makedirs() check.
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8e9 looks like a token-count limit for the counter — confirm in rxnn.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves under ./att_mini/gma and pushes checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att_mini/gma', push_to_hub=True,
                            hub_model_id='ReactiveAI/GMAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GroupedMoeAttentionTransformer Mini v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# MoE attention: the MoE auxiliary loss is enabled, scaled by 0.01.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock the whole trainer call, so the time also covers validation and callbacks.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Total time is shown in minutes, per-step time in seconds.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### Train DMA Mini model

In [None]:
# selected model: DMA
decoder = dma_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held over from earlier cells before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# Batches per epoch minus one. NOTE(review): presumably discounts a trailing partial
# batch — confirm against the trainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps; warmup covers the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att_mini/tensorboard_logs/dma'
# exist_ok=True replaces the racy exists()-then-makedirs() check.
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8e9 looks like a token-count limit for the counter — confirm in rxnn.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves under ./att_mini/dma and pushes checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att_mini/dma', push_to_hub=True,
                            hub_model_id='ReactiveAI/DMAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='DeepMoeAttentionTransformer Mini v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# MoE attention: the MoE auxiliary loss is enabled, scaled by 0.01.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock the whole trainer call, so the time also covers validation and callbacks.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Total time is shown in minutes, per-step time in seconds.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

## Mixture-of-Experts Attention - Next round

> An additional training epoch to confirm that **MoE Attention** is not worth using.

> Unfortunately, Mixture-of-Experts Attention performed worse than classic GQA/MQA, so we abandoned this study and moved on to **Sparse Query Attention** research.

### GQA

In [None]:
# Training data for the second round. `_rest_dataset` comes from an earlier
# cell not shown here; its name suggests the part of the corpus unused in
# round one - confirm.
next_dataset = _rest_dataset

In [None]:
# selected model: GQA
# Second round: load the GQA reference model from the Hub, train it for one
# more epoch on `next_dataset` (validating on `wiki_valid_dataset`), and push
# the result back to the same 'ReactiveAI/GQA-MAT-m' repo.
# NOTE(review): `MoeAttentionTransformer` is not in the notebook's import cell
# (which imports `ExperimentalAttentionTransformer`); confirm it is defined or
# imported before this cell runs.
decoder = MoeAttentionTransformer.from_pretrained('ReactiveAI/GQA-MAT-m')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release unreferenced Python objects and PyTorch's cached-but-unused CUDA memory.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate scales linearly with the gradient-accumulation steps.
peak_lr = 3e-4 * gradient_acc_steps

subset_len = int(len(next_dataset))

# Full batches per epoch, minus one.
# NOTE(review): the reason for the `- 1` is not visible here; confirm it
# matches how the trainer batches the dataset.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps (one step per `gradient_acc_steps` batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# warmup_steps = int(0.1 * total_steps)
# No warmup for this continued-training round.
warmup_steps = 0


# Per-run log directory, passed to the trainer as `log_dir`.
logs_dir = './att_mini/tensorboard_logs/gqa_rest'
if not os.path.exists(logs_dir):
  os.makedirs(logs_dir)


# NOTE(review): 8_000_000_000 is presumably a token budget for the counter - confirm.
print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Save under './att_mini/gqa_rest' and push to the same private repo the model was loaded from.
save_cb = ModelSaveCallback('./att_mini/gqa_rest', push_to_hub=True,
                            hub_model_id='ReactiveAI/GQA-MAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GQA Mini Ref Transformer v2')

train_dataset = AutoregressiveLMDataset(next_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# bf16 autocast training; MoE auxiliary loss disabled for this GQA baseline.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Time the whole trainer call (training plus any validation/callback work), in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
# Average seconds per planned optimizer step.
time_per_step = training_time / total_steps

# Cell output: total time in minutes, then seconds per step.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### GMA

In [None]:
# selected model: GMA
# Second round: load the GMA model ('GroupedMoeAttentionTransformer' per the
# commit message) from the Hub, train it for one more epoch on `next_dataset`
# (validating on `wiki_valid_dataset`), and push it back to 'ReactiveAI/GMAT-m'.
# NOTE(review): `MoeAttentionTransformer` is not in the notebook's import cell
# (which imports `ExperimentalAttentionTransformer`); confirm it is defined or
# imported before this cell runs.
decoder = MoeAttentionTransformer.from_pretrained('ReactiveAI/GMAT-m')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release unreferenced Python objects and PyTorch's cached-but-unused CUDA memory.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate scales linearly with the gradient-accumulation steps.
peak_lr = 3e-4 * gradient_acc_steps

subset_len = int(len(next_dataset))

# Full batches per epoch, minus one.
# NOTE(review): the reason for the `- 1` is not visible here; confirm it
# matches how the trainer batches the dataset.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps (one step per `gradient_acc_steps` batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# warmup_steps = int(0.1 * total_steps)
# No warmup for this continued-training round.
warmup_steps = 0


# Per-run log directory, passed to the trainer as `log_dir`.
logs_dir = './att_mini/tensorboard_logs/gma_rest'
if not os.path.exists(logs_dir):
  os.makedirs(logs_dir)


# NOTE(review): 8_000_000_000 is presumably a token budget for the counter - confirm.
print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Save under './att_mini/gma_rest' and push to the same private repo the model was loaded from.
save_cb = ModelSaveCallback('./att_mini/gma_rest', push_to_hub=True,
                            hub_model_id='ReactiveAI/GMAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GroupedMoeAttentionTransformer Mini v2')

train_dataset = AutoregressiveLMDataset(next_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# bf16 autocast training with the MoE auxiliary loss enabled (scale 0.01).
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Time the whole trainer call (training plus any validation/callback work), in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
# Average seconds per planned optimizer step.
time_per_step = training_time / total_steps

# Cell output: total time in minutes, then seconds per step.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### DMA

In [None]:
# selected model: DMA
# Second round: load the DMA model ('DeepMoeAttentionTransformer' per the
# commit message) from the Hub, train it for one more epoch on `next_dataset`
# (validating on `wiki_valid_dataset`), and push it back to 'ReactiveAI/DMAT-m'.
# NOTE(review): `MoeAttentionTransformer` is not in the notebook's import cell
# (which imports `ExperimentalAttentionTransformer`); confirm it is defined or
# imported before this cell runs.
decoder = MoeAttentionTransformer.from_pretrained('ReactiveAI/DMAT-m')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release unreferenced Python objects and PyTorch's cached-but-unused CUDA memory.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate scales linearly with the gradient-accumulation steps.
peak_lr = 3e-4 * gradient_acc_steps

subset_len = int(len(next_dataset))

# Full batches per epoch, minus one.
# NOTE(review): the reason for the `- 1` is not visible here; confirm it
# matches how the trainer batches the dataset.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps (one step per `gradient_acc_steps` batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# warmup_steps = int(0.1 * total_steps)
# No warmup for this continued-training round.
warmup_steps = 0


# Per-run log directory, passed to the trainer as `log_dir`.
logs_dir = './att_mini/tensorboard_logs/dma_rest'
if not os.path.exists(logs_dir):
  os.makedirs(logs_dir)


# NOTE(review): 8_000_000_000 is presumably a token budget for the counter - confirm.
print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Save under './att_mini/dma_rest' and push to the same private repo the model was loaded from.
save_cb = ModelSaveCallback('./att_mini/dma_rest', push_to_hub=True,
                            hub_model_id='ReactiveAI/DMAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='DeepMoeAttentionTransformer Mini v2')

train_dataset = AutoregressiveLMDataset(next_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# bf16 autocast training with the MoE auxiliary loss enabled (scale 0.01).
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Time the whole trainer call (training plus any validation/callback work), in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
# Average seconds per planned optimizer step.
time_per_step = training_time / total_steps

# Cell output: total time in minutes, then seconds per step.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

## SparseQueryAttention - Mini models test

> Tests of **SQA** variants, with all other parameters the same as in the Mini models above.

### Sparse Query Attention base

In [None]:
# SQA Mini model: `base_config` with att_type='sqa' and 8 query groups;
# every other setting is inherited unchanged.
sqa_config = dict(base_config, att_type='sqa', att_num_query_groups=8)

sqa_decoder = ExperimentalAttentionTransformer(**sqa_config)
sqa_decoder.params_count()

In [None]:
# selected model: SQA
# Train the SQA Mini decoder for one epoch on `wiki_train_dataset`, validating
# on `wiki_valid_dataset`, and push checkpoints to the private
# 'ReactiveAI/SQAT-m' Hub repo. The cell output reports the training time.
decoder = sqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release unreferenced Python objects and PyTorch's cached-but-unused CUDA memory.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate scales linearly with the gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# Full batches per epoch, minus one.
# NOTE(review): the reason for the `- 1` is not visible here; confirm it
# matches how the trainer batches the dataset.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps (one step per `gradient_acc_steps` batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# warmup_steps = int(0.1 * total_steps)
# Warm up over the first 25% of one epoch's steps.
warmup_steps = int(0.25 * steps_per_epoch)


# Per-run log directory, passed to the trainer as `log_dir`.
logs_dir = './att_mini/tensorboard_logs/sqa'
if not os.path.exists(logs_dir):
  os.makedirs(logs_dir)


# NOTE(review): 8_000_000_000 is presumably a token budget for the counter - confirm.
print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Save under './att_mini/sqa' and push checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att_mini/sqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/SQAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='SQA Mini Ref Transformer v2')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# bf16 autocast training; MoE auxiliary loss disabled (no MoE layers configured here).
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Time the whole trainer call (training plus any validation/callback work), in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
# Average seconds per planned optimizer step.
time_per_step = training_time / total_steps

# Cell output: total time in minutes, then seconds per step.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### SparseQueryAttention - extreme version

In [None]:
# Extreme SQA (xSQA) Mini model: `base_config` with att_type='sqa' and only
# 4 query groups; every other setting is inherited unchanged.
xsqa_config = dict(base_config, att_type='sqa', att_num_query_groups=4)

xsqa_decoder = ExperimentalAttentionTransformer(**xsqa_config)
xsqa_decoder.params_count()

In [None]:
# selected model: xSQA
# Train the xSQA Mini decoder for one epoch on `wiki_train_dataset`,
# validating on `wiki_valid_dataset`, and push checkpoints to the private
# 'ReactiveAI/xSQAT-m' Hub repo. The cell output reports the training time.
decoder = xsqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release unreferenced Python objects and PyTorch's cached-but-unused CUDA memory.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate scales linearly with the gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# Full batches per epoch, minus one.
# NOTE(review): the reason for the `- 1` is not visible here; confirm it
# matches how the trainer batches the dataset.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps (one step per `gradient_acc_steps` batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# warmup_steps = int(0.1 * total_steps)
# Warm up over the first 25% of one epoch's steps.
warmup_steps = int(0.25 * steps_per_epoch)


# Per-run log directory, passed to the trainer as `log_dir`.
logs_dir = './att_mini/tensorboard_logs/xsqa'
if not os.path.exists(logs_dir):
  os.makedirs(logs_dir)


# NOTE(review): 8_000_000_000 is presumably a token budget for the counter - confirm.
print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Save under './att_mini/xsqa' and push checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att_mini/xsqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/xSQAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='xSQA Mini Transformer v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# bf16 autocast training; MoE auxiliary loss disabled (no MoE layers configured here).
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Time the whole trainer call (training plus any validation/callback work), in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
# Average seconds per planned optimizer step.
time_per_step = training_time / total_steps

# Cell output: total time in minutes, then seconds per step.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### Symmetric SQA Variant

In [None]:
# Symmetric SQA (sSQA) Mini model: `base_config` with att_type='sqa',
# 8 query groups and `att_groups` raised to 8 to match.
ssqa_config = dict(base_config, att_type='sqa', att_num_query_groups=8, att_groups=8)

ssqa_decoder = ExperimentalAttentionTransformer(**ssqa_config)
ssqa_decoder.params_count()

In [None]:
# selected model: sSQA
# Train the sSQA Mini decoder for one epoch on `wiki_train_dataset`,
# validating on `wiki_valid_dataset`, and push checkpoints to the private
# 'ReactiveAI/sSQAT-m' Hub repo. The cell output reports the training time.
decoder = ssqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release unreferenced Python objects and PyTorch's cached-but-unused CUDA memory.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate scales linearly with the gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# Full batches per epoch, minus one.
# NOTE(review): the reason for the `- 1` is not visible here; confirm it
# matches how the trainer batches the dataset.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps (one step per `gradient_acc_steps` batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# warmup_steps = int(0.1 * total_steps)
# Warm up over the first 25% of one epoch's steps.
warmup_steps = int(0.25 * steps_per_epoch)


# Per-run log directory, passed to the trainer as `log_dir`.
logs_dir = './att_mini/tensorboard_logs/ssqa'
if not os.path.exists(logs_dir):
  os.makedirs(logs_dir)


# NOTE(review): 8_000_000_000 is presumably a token budget for the counter - confirm.
print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Save under './att_mini/ssqa' and push checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att_mini/ssqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/sSQAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='sSQA Mini Transformer v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# bf16 autocast training; MoE auxiliary loss disabled (no MoE layers configured here).
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Time the whole trainer call (training plus any validation/callback work), in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
# Average seconds per planned optimizer step.
time_per_step = training_time / total_steps

# Cell output: total time in minutes, then seconds per step.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

### Extreme Sparse Multi Query Attention

In [None]:
# Extreme Sparse Multi-Query (xSMQA) Mini model: `base_config` with
# att_type='sqa', 4 query groups and a single KV group (`att_groups=1`).
xsmqa_config = dict(base_config, att_type='sqa', att_num_query_groups=4, att_groups=1)

xsmqa_decoder = ExperimentalAttentionTransformer(**xsmqa_config)
xsmqa_decoder.params_count()

In [None]:
# selected model: xSMQA
# Train the xSMQA Mini decoder for one epoch on `wiki_train_dataset`,
# validating on `wiki_valid_dataset`, and push checkpoints to the private
# 'ReactiveAI/xSMQAT-m' Hub repo. The cell output reports the training time.
decoder = xsmqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release unreferenced Python objects and PyTorch's cached-but-unused CUDA memory.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate scales linearly with the gradient-accumulation steps.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# Full batches per epoch, minus one.
# NOTE(review): the reason for the `- 1` is not visible here; confirm it
# matches how the trainer batches the dataset.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps (one step per `gradient_acc_steps` batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# warmup_steps = int(0.1 * total_steps)
# Warm up over the first 25% of one epoch's steps.
warmup_steps = int(0.25 * steps_per_epoch)


# Per-run log directory, passed to the trainer as `log_dir`.
logs_dir = './att_mini/tensorboard_logs/xsmqa'
if not os.path.exists(logs_dir):
  os.makedirs(logs_dir)


# NOTE(review): 8_000_000_000 is presumably a token budget for the counter - confirm.
print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Save under './att_mini/xsmqa' and push checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./att_mini/xsmqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/xSMQAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='xSMQA Mini Transformer v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# bf16 autocast training; MoE auxiliary loss disabled (no MoE layers configured here).
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Time the whole trainer call (training plus any validation/callback work), in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
# Average seconds per planned optimizer step.
time_per_step = training_time / total_steps

# Cell output: total time in minutes, then seconds per step.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

## Mini models computational performance comparison


- Nvidia L40S 48GB GPU
- 200 steps
- 32 batch size
```
Model:   (time per 200 steps, time per single batch)
'GQA':   (3.097075939178467, 0.015485379695892334)
'GQAs':  (3.107764959335327, 0.015538824796676636)
'SQA':   (2.827148914337158, 0.014135744571685791)
'xSQA':  (2.6865649223327637, 0.013432824611663818)
'xSMQA': (2.6817448139190674, 0.013408724069595337)
'sSQA':  (2.9016659259796143, 0.014508329629898072)
'MQA':   (2.898437023162842, 0.014492185115814208)
```
- 128 batch size
```
 Model:  (time per 200 steps, time per single batch)
'GQA':   (19.43833899497986, 0.0971916949748993)
'GQAs':  (19.273487091064453, 0.09636743545532227)
'SQA':   (16.972065210342407, 0.08486032605171204)
'xSQA':  (16.232131004333496, 0.08116065502166749)
'xSMQA': (15.909909963607788, 0.07954954981803894)
'sSQA':  (17.343858003616333, 0.08671929001808167)
'MQA':   (17.349792957305908, 0.08674896478652955)
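```

> **Caution:** the MQA rows above are nearly identical to the sSQA rows. The benchmark code that fills `results['MQA']` has been observed to run `ssqa_decoder` instead of `mqa_decoder`, so these MQA numbers are likely sSQA measurements. Re-run the benchmark before relying on them.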

In [None]:
# Forward-pass timing benchmark for the Mini attention variants.
# Every model runs `steps` forward passes on the same constant token batch
# under bf16 autocast; `results` maps name -> (total seconds, seconds per step).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
torch.cuda.empty_cache()
print(device)

inp = torch.ones(128, 1024, dtype=torch.int32).to(device)
gqa_decoder = gqa_decoder.to(device)
gqas_decoder = gqas_decoder.to(device)
sqa_decoder = sqa_decoder.to(device)
xsqa_decoder = xsqa_decoder.to(device)
xsmqa_decoder = xsmqa_decoder.to(device)
ssqa_decoder = ssqa_decoder.to(device)
mqa_decoder = mqa_decoder.to(device)

steps = 20
warmup = 100

# Model timed under each result key, in benchmark order. Each key must map to
# its own model (e.g. 'MQA' -> mqa_decoder), otherwise two rows measure the
# same network.
benchmark_models = {
    'GQA': gqa_decoder,
    'GQAs': gqas_decoder,
    'SQA': sqa_decoder,
    'xSQA': xsqa_decoder,
    'xSMQA': xsmqa_decoder,
    'sSQA': ssqa_decoder,
    'MQA': mqa_decoder,
}

results = {bench_name: (0.0, 0.0) for bench_name in benchmark_models}


def _sync(dev):
  """Block until all queued CUDA work has finished; no-op on non-CUDA devices.

  CUDA kernels are launched asynchronously, so without this the timer can
  stop (or start) while kernels from another loop are still executing.
  """
  if dev.type == 'cuda':
    torch.cuda.synchronize(dev)


def _time_forward(model, x, n_steps, dev):
  """Run `n_steps` forward passes of `model` on `x`.

  Returns:
    (total wall-clock seconds, seconds per step).
  """
  _sync(dev)
  t_start = datetime.timestamp(datetime.now())
  for _ in range(n_steps):
    _ = model(x)
  _sync(dev)
  t_end = datetime.timestamp(datetime.now())
  return (t_end - t_start, (t_end - t_start) / n_steps)


with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):

  print(inp.dtype)
  # Warm-up passes run on the GQA model only; the other models are timed
  # from their first call.
  for _ in range(warmup):
    _ = gqa_decoder(inp)

  for bench_name, bench_model in benchmark_models.items():
    results[bench_name] = _time_forward(bench_model, inp, steps, device)

results

# SQA - Micro MoE models full training test

> Tests of **SQA** variants on micro Mixture-of-Experts architectures, compared with **GQA** and **MQA**.

## Init models

In [None]:
# Micro MoE models: shared sizes, reproducible seeding, and one decoder per
# attention variant. The cell displays each model's parameter count.
embed_dim, vocab_size, seq_len = 128, 5_000, 256

# Seed every RNG before any model is constructed.
torch.random.manual_seed(2137)
np.random.seed(2137)
random.seed(2137)

# Settings shared by every variant; only the attention settings differ below.
base_config = dict(
    num_layers=6,
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    att_heads=8,
    att_groups=2,
    seq_len=seq_len,
    use_flash_attention=False,
    use_gated=True,
    ff_dropout=0.1,
    ff_activation='silu',
    ff_dim=256,
    use_moe_ff=True,
    ff_num_experts=12,
    ff_moe_top_k=2,
)

# Per-variant attention overrides layered on top of base_config.
gqa_config = dict(base_config, att_type='gqa')
mqa_config = dict(base_config, att_type='mqa')
sqa_config = dict(base_config, att_type='sqa', att_num_query_groups=4)
ssqa_config = dict(base_config, att_type='sqa', att_num_query_groups=4, att_groups=4)
xsqa_config = dict(base_config, att_type='sqa', att_num_query_groups=2)

# Kept in the original construction order: weight initialisation presumably
# draws from the seeded RNGs, so the order affects the initial weights.
gqa_decoder = ExperimentalAttentionTransformer(**gqa_config)
mqa_decoder = ExperimentalAttentionTransformer(**mqa_config)
sqa_decoder = ExperimentalAttentionTransformer(**sqa_config)
ssqa_decoder = ExperimentalAttentionTransformer(**ssqa_config)
xsqa_decoder = ExperimentalAttentionTransformer(**xsqa_config)

tuple(
    (label, model.params_count())
    for label, model in (
        ('GQA', gqa_decoder),
        ('MQA', mqa_decoder),
        ('SQA', sqa_decoder),
        ('sSQA', ssqa_decoder),
        ('xSQA', xsqa_decoder),
    )
)

In [None]:
# Load the tokenizer published with 'ReactiveAI/RxT-Alpha-Micro-Decoder' and
# convert it via `get_hf_tokenizer()` for the dataset wrappers below.
tr = TokenizerTrainer.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Decoder')
tokenizer = tr.get_hf_tokenizer()
tokenizer

In [None]:
# Load the TinyStories train and validation splits from the Hugging Face Hub
# and display their sizes.
# NOTE(review): `trust_remote_code=True` lets the dataset repo execute its own
# loading code. If the repo ships no loading script, the flag is unnecessary;
# consider dropping it.
from datasets import load_dataset
stories_dataset = load_dataset('roneneldan/TinyStories', split='train', trust_remote_code=True)
hf_valid_dataset = load_dataset('roneneldan/TinyStories', split='validation', trust_remote_code=True)
len(stories_dataset), len(hf_valid_dataset)

## GQA

In [None]:
# selected model: GQA
# Train the micro MoE GQA reference decoder for 5 epochs on TinyStories,
# validating on its validation split, and push checkpoints to the private
# 'ReactiveAI/GQA-Ref-Micro' Hub repo. The cell output reports training time.
decoder = gqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release unreferenced Python objects and PyTorch's cached-but-unused CUDA memory.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 5
gradient_acc_steps = 1

# Peak learning rate scales linearly with the gradient-accumulation steps.
peak_lr = 2e-3 * gradient_acc_steps

subset_len = int(len(stories_dataset))

# Full batches per epoch, minus one.
# NOTE(review): the reason for the `- 1` is not visible here; confirm it
# matches how the trainer batches the dataset.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps (one step per `gradient_acc_steps` batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# warmup_steps = int(0.1 * total_steps)
warmup_steps = 0


# Per-run log directory, passed to the trainer as `log_dir`.
logs_dir = './micro_att/tensorboard_logs/gqa'
if not os.path.exists(logs_dir):
  os.makedirs(logs_dir)


# NOTE(review): 8_000_000_000 is presumably a token budget for the counter - confirm.
print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Save under './micro_att/gqa' and push checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./micro_att/gqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/GQA-Ref-Micro', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GQA Ref Micro Transformer v2')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# bf16 autocast training with the MoE auxiliary loss enabled (scale 0.01).
# `gradient_accumulation_steps` is passed explicitly, as in the Mini cells, so
# the trainer stays consistent with the `peak_lr` / `total_steps` scaling above.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Time the whole trainer call (training plus any validation/callback work), in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
# Average seconds per planned optimizer step.
time_per_step = training_time / total_steps

# Cell output: total time in minutes, then seconds per step.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

## MQA

In [None]:
# selected model: MQA
# Train the micro MoE MQA reference decoder for 5 epochs on TinyStories,
# validating on its validation split, and push checkpoints to the private
# 'ReactiveAI/MQA-Ref-Micro' Hub repo. The cell output reports training time.
decoder = mqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release unreferenced Python objects and PyTorch's cached-but-unused CUDA memory.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 5
gradient_acc_steps = 1

# Peak learning rate scales linearly with the gradient-accumulation steps.
peak_lr = 2e-3 * gradient_acc_steps

subset_len = int(len(stories_dataset))

# Full batches per epoch, minus one.
# NOTE(review): the reason for the `- 1` is not visible here; confirm it
# matches how the trainer batches the dataset.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps (one step per `gradient_acc_steps` batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# warmup_steps = int(0.1 * total_steps)
warmup_steps = 0


# Per-run log directory, passed to the trainer as `log_dir`.
logs_dir = './micro_att/tensorboard_logs/mqa'
if not os.path.exists(logs_dir):
  os.makedirs(logs_dir)


# NOTE(review): 8_000_000_000 is presumably a token budget for the counter - confirm.
print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Save under './micro_att/mqa' and push checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./micro_att/mqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/MQA-Ref-Micro', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='MQA Ref Micro Transformer v2')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# bf16 autocast training with the MoE auxiliary loss enabled (scale 0.01).
# `gradient_accumulation_steps` is passed explicitly, as in the Mini cells, so
# the trainer stays consistent with the `peak_lr` / `total_steps` scaling above.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Time the whole trainer call (training plus any validation/callback work), in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
# Average seconds per planned optimizer step.
time_per_step = training_time / total_steps

# Cell output: total time in minutes, then seconds per step.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

## SQA

In [None]:
# selected model: SQA
# Train the micro MoE SQA decoder for 5 epochs on TinyStories, validating on
# its validation split, and push checkpoints to the private
# 'ReactiveAI/SQAT-mm' Hub repo. The cell output reports training time.
decoder = sqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release unreferenced Python objects and PyTorch's cached-but-unused CUDA memory.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 5
gradient_acc_steps = 1

# Peak learning rate scales linearly with the gradient-accumulation steps.
peak_lr = 2e-3 * gradient_acc_steps

subset_len = int(len(stories_dataset))

# Full batches per epoch, minus one.
# NOTE(review): the reason for the `- 1` is not visible here; confirm it
# matches how the trainer batches the dataset.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps (one step per `gradient_acc_steps` batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# warmup_steps = int(0.1 * total_steps)
warmup_steps = 0


# Per-run log directory, passed to the trainer as `log_dir`.
logs_dir = './micro_att/tensorboard_logs/sqa'
if not os.path.exists(logs_dir):
  os.makedirs(logs_dir)


# NOTE(review): 8_000_000_000 is presumably a token budget for the counter - confirm.
print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Save under './micro_att/sqa' and push checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./micro_att/sqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/SQAT-mm', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='SQA-mm Transformer v1.0.0')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# bf16 autocast training with the MoE auxiliary loss enabled (scale 0.01).
# `gradient_accumulation_steps` is passed explicitly, as in the Mini cells, so
# the trainer stays consistent with the `peak_lr` / `total_steps` scaling above.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Time the whole trainer call (training plus any validation/callback work), in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
# Average seconds per planned optimizer step.
time_per_step = training_time / total_steps

# Cell output: total time in minutes, then seconds per step.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

## sSQA

In [None]:
# selected model: sSQA
# Train the micro MoE sSQA decoder for 5 epochs on TinyStories, validating on
# its validation split, and push checkpoints to the private
# 'ReactiveAI/sSQAT-mm' Hub repo. The cell output reports training time.
decoder = ssqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release unreferenced Python objects and PyTorch's cached-but-unused CUDA memory.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 5
gradient_acc_steps = 1

# Peak learning rate scales linearly with the gradient-accumulation steps.
peak_lr = 2e-3 * gradient_acc_steps

subset_len = int(len(stories_dataset))

# Full batches per epoch, minus one.
# NOTE(review): the reason for the `- 1` is not visible here; confirm it
# matches how the trainer batches the dataset.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps (one step per `gradient_acc_steps` batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# warmup_steps = int(0.1 * total_steps)
warmup_steps = 0


# Per-run log directory, passed to the trainer as `log_dir`.
logs_dir = './micro_att/tensorboard_logs/ssqa'
if not os.path.exists(logs_dir):
  os.makedirs(logs_dir)


# NOTE(review): 8_000_000_000 is presumably a token budget for the counter - confirm.
print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Save under './micro_att/ssqa' and push checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./micro_att/ssqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/sSQAT-mm', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='sSQA-mm Transformer v1.0.0')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# bf16 autocast training with the MoE auxiliary loss enabled (scale 0.01).
# `gradient_accumulation_steps` is passed explicitly, as in the Mini cells, so
# the trainer stays consistent with the `peak_lr` / `total_steps` scaling above.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Time the whole trainer call (training plus any validation/callback work), in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
# Average seconds per planned optimizer step.
time_per_step = training_time / total_steps

# Cell output: total time in minutes, then seconds per step.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

## xSQA

In [None]:
# selected model: xSQA
# Train the micro MoE xSQA decoder for 5 epochs on TinyStories, validating on
# its validation split, and push checkpoints to the private
# 'ReactiveAI/xSQAT-mm' Hub repo. The cell output reports training time.
decoder = xsqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release unreferenced Python objects and PyTorch's cached-but-unused CUDA memory.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 5
gradient_acc_steps = 1

# Peak learning rate scales linearly with the gradient-accumulation steps.
peak_lr = 2e-3 * gradient_acc_steps

subset_len = int(len(stories_dataset))

# Full batches per epoch, minus one.
# NOTE(review): the reason for the `- 1` is not visible here; confirm it
# matches how the trainer batches the dataset.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

# Scheduler horizon in optimizer steps (one step per `gradient_acc_steps` batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# warmup_steps = int(0.1 * total_steps)
warmup_steps = 0


# Per-run log directory, passed to the trainer as `log_dir`.
logs_dir = './micro_att/tensorboard_logs/xsqa'
if not os.path.exists(logs_dir):
  os.makedirs(logs_dir)


# NOTE(review): 8_000_000_000 is presumably a token budget for the counter - confirm.
print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Save under './micro_att/xsqa' and push checkpoint weights to a private Hub repo.
save_cb = ModelSaveCallback('./micro_att/xsqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/xSQAT-mm', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='xSQA-mm Transformer v1.0.0')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# bf16 autocast training with the MoE auxiliary loss enabled (scale 0.01).
# `gradient_accumulation_steps` is passed explicitly, as in the Mini cells, so
# the trainer stays consistent with the `peak_lr` / `total_steps` scaling above.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Time the whole trainer call (training plus any validation/callback work), in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
# Average seconds per planned optimizer step.
time_per_step = training_time / total_steps

# Cell output: total time in minutes, then seconds per step.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

# Longer-context computational efficiency

## Init models

In [None]:
torch.random.manual_seed(42)
np.random.seed(42)
random.seed(42)


def create_test_models(seq_len: int = 1024):
  """Build one ExperimentalAttentionTransformer per attention variant for benchmarking.

  Every variant shares an 8-layer, 256-dim backbone (10k vocabulary, 16 heads, 4 KV groups,
  gated SiLU feed-forward of width 768) and differs only in the attention overrides listed
  in the variant table below. Parameter counts are printed as a tuple of (name, count) pairs.

  Args:
    seq_len: sequence length passed to every model's configuration.

  Returns:
    dict mapping 'GQA', 'MQA', 'MHA', 'SQA', 'sSQA', 'xSQA' to the corresponding model.
  """
  shared = dict(
      num_layers=8,
      vocab_size=10_000,
      embed_dim=256,
      att_heads=16,
      att_groups=4,
      seq_len=seq_len,
      use_flash_attention=False,
      use_gated=True,
      ff_dropout=0.1,
      ff_activation='silu',
      ff_dim=768,
  )

  # Per-variant overrides merged on top of `shared` (override values win).
  variants = {
      'GQA': {'att_type': 'gqa'},
      'MQA': {'att_type': 'mqa'},
      'MHA': {'att_type': 'mha'},
      'SQA': {'att_type': 'sqa', 'att_num_query_groups': 8},
      'sSQA': {'att_type': 'sqa', 'att_num_query_groups': 8, 'att_groups': 8},
      'xSQA': {'att_type': 'sqa', 'att_num_query_groups': 4},
  }

  # Models are constructed in table order, all before any parameter counting.
  built = {}
  for label, overrides in variants.items():
    built[label] = ExperimentalAttentionTransformer(**{**shared, **overrides})

  print(tuple((label, model.params_count()) for label, model in built.items()))

  return built

In [None]:
def time_tests(models: dict[str, torch.nn.Module], batch_size: int = 64, seq_len: int = 1024,
               steps: int = 50, warmup: int = 10):
  """Measure forward-pass wall-clock time for each model on one shared random batch.

  Each model is moved to the benchmark device, run `warmup` untimed forward passes, then
  `steps` timed forward passes, all under bfloat16 autocast with autograd disabled.
  Models are benchmarked in the iteration order of `models`.

  Args:
    models: mapping from a display name to a module called with a (batch_size, seq_len)
      int32 token tensor.
    batch_size: number of sequences in the benchmark batch.
    seq_len: tokens per sequence.
    steps: timed forward passes per model (must be >= 1).
    warmup: untimed forward passes run before timing each model.

  Returns:
    dict mapping each model name to ``(total_seconds, seconds_per_step)``.

  Raises:
    ValueError: if ``steps`` is less than 1.
  """
  import time

  if steps < 1:
    raise ValueError(f'steps must be >= 1, got {steps}')

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  on_cuda = device.type == 'cuda'
  if on_cuda:
    torch.cuda.empty_cache()

  print(device)

  # Random token ids in [1, 9999): in range for the 10k-token vocabulary used by create_test_models.
  inp = torch.randint(1, 9999, (batch_size, seq_len), dtype=torch.int32).to(device)
  on_device = {name: model.to(device) for name, model in models.items()}

  def _wait_for_device():
    # CUDA kernels are launched asynchronously; block until they finish so the clock
    # measures the real GPU work, not just kernel-launch time.
    if on_cuda:
      torch.cuda.synchronize()

  results = {}
  # Forward-only benchmark: no_grad avoids building and holding autograd graphs,
  # which matters for memory at very long sequence lengths.
  with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
    print(inp.dtype)
    for name, model in on_device.items():
      # Warm up every model so no variant pays one-time first-call costs in its timed loop.
      for _ in range(warmup):
        model(inp)
      _wait_for_device()

      start = time.perf_counter()  # monotonic, high-resolution clock
      for _ in range(steps):
        model(inp)
      _wait_for_device()
      elapsed = time.perf_counter() - start

      results[name] = (elapsed, elapsed / steps)

  return results

#### 1024 Sequence / 128 batch size / 50 steps

In [None]:
# 1024-token context, batch of 128.
time_tests(create_test_models(seq_len=1024), batch_size=128, seq_len=1024)

#### 4096 Sequence / 32 batch size / 50 steps

In [None]:
# 4096-token context, batch of 32.
time_tests(create_test_models(seq_len=4096), batch_size=32, seq_len=4096)

#### 32768 Sequence / 4 batch size / 50 steps

In [None]:
# 32768-token context, batch of 4.
time_tests(create_test_models(seq_len=32768), batch_size=4, seq_len=32768)

#### 131072 Sequence / 1 batch size / 50 steps

In [None]:
# 131072-token context, single sequence.
time_tests(create_test_models(seq_len=131072), batch_size=1, seq_len=131072)