<a href="https://colab.research.google.com/github/RxAI-dev/rxnn-notebooks/blob/main/Experimental_Attention_Check.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install RxNN and dependencies

In [None]:
!pip install rxnn==0.1.59 torch==2.6.0 transformers tokenizers huggingface_hub datasets

Collecting rxnn==0.1.59
  Downloading rxnn-0.1.59-py3-none-any.whl.metadata (16 kB)
Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting flash-attention<2.0.0,>=1.0.0 (from rxnn==0.1.59)
  Downloading flash_attention-1.0.0-py3-none-any.whl.metadata (274 bytes)
Collecting huggingface_hub
  Downloading huggingface_hub-0.30.2-py3-none-any.whl.metadata (13 kB)
Collecting tensorboard<3.0.0,>=2.19.0 (from rxnn==0.1.59)
  Downloading tensorboard-2.19.0-py3-none-any.whl.metadata (1.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.me

In [None]:
!wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl

--2025-05-29 17:39:44--  https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
Resolving github.com (github.com)... 140.82.113.3
Connecting to github.com (github.com)|140.82.113.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/494232964/e8020a46-5dad-4306-b1de-42a54fc0813b?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250529%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250529T173945Z&X-Amz-Expires=300&X-Amz-Signature=129ad4441ad98258820de5ea8c3eb00d3c21239593ffd0448db9b23d1b7cb156&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Dflash_attn-2.7.4.post1%2Bcu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl&response-content-type=application%2Foctet-stream [following]
--2025-05-29 17:39:45--  https://objects.githubuserc

In [None]:
!pip install --no-dependencies --upgrade flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl

Processing ./flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
Installing collected packages: flash-attn
Successfully installed flash-attn-2.7.4.post1


## Import libraries

In [None]:
import gc
import os
import random
import time
from datetime import datetime

import numpy as np
import torch

from rxnn.experimental.models import ExperimentalAttentionTransformer
from rxnn.training.dataset import AutoregressiveLMDataset
from rxnn.training.bml import AutoregressiveTrainer
from rxnn.training.scheduler import get_transformer_lr_scheduler
from rxnn.training.callbacks import PrintLossCallback, PrintAccuracyCallback, TokenCounterCallback, ModelSaveCallback
from rxnn.training.tokenizer import TokenizerTrainer, load_tokenizer_from_hf_hub
from rxnn.utils import get_model_size

In [None]:
# Confirm which torch build the kernel actually loaded; the flash-attn wheel installed above
# is built for torch 2.6 ("cu12torch2.6" in its filename).
import torch
torch.__version__

'2.6.0+cu124'

In [None]:
# Check that flash_attn imports against the installed torch and report its version
# (the wheel installed above is 2.7.4.post1).
import flash_attn
flash_attn.__version__

'2.7.4.post1'

## Mixture-of-Experts Attention - Test on micro-size models (~2.5M params) with the TinyStories dataset

> Unfortunately, Mixture-of-Experts Attention gave worse results than classic GQA/MQA, so we abandoned this study and moved on to **Sparse Query Attention** research.



In [None]:
# Micro-size setup (~2.5M params per model): 128-dim embeddings, 5k vocab, 256-token context.
embed_dim = 128
vocab_size = 5_000
seq_len = 256

MICRO_SEED = 2137
torch.random.manual_seed(MICRO_SEED)
np.random.seed(MICRO_SEED)
random.seed(MICRO_SEED)

# Shared backbone hyper-parameters; the variants below differ only in attention settings.
base_config = dict(
    num_layers=6,
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    att_heads=8,
    att_groups=2,
    seq_len=seq_len,
    use_flash_attention=False,
    use_gated=True,
    ff_dropout=0.1,
    ff_activation='silu',
    ff_dim=384,
)

gqa_config = dict(base_config, att_type='gqa')
mqa_config = dict(base_config, att_type='mqa')
mha_config = dict(base_config, att_type='mha')
gma_config = dict(base_config, att_type='gma')
dma_config = dict(base_config, att_type='dma', att_num_query_groups=4)

# Keep this construction order: each model draws its initial weights from the seeded RNG in turn.
gqa_decoder = ExperimentalAttentionTransformer(**gqa_config)
mqa_decoder = ExperimentalAttentionTransformer(**mqa_config)
mha_decoder = ExperimentalAttentionTransformer(**mha_config)
gma_decoder = ExperimentalAttentionTransformer(**gma_config)
dma_decoder = ExperimentalAttentionTransformer(**dma_config)

tuple(
    (label, model.params_count())
    for label, model in (
        ('GQA', gqa_decoder),
        ('MQA', mqa_decoder),
        ('MHA', mha_decoder),
        ('GMA', gma_decoder),
        ('DMA', dma_decoder),
    )
)

(('GQA', 'Model params 2.42M'),
 ('MQA', 'Model params 2.4M'),
 ('MHA', 'Model params 2.57M'),
 ('GMA', 'Model params 2.58M'),
 ('DMA', 'Model params 2.53M'))

### Computational efficiency

In [None]:
# Forward-pass timing of the five micro models on one dummy batch, under bf16 autocast.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
  torch.cuda.empty_cache()
print(device)

# Dummy batch: 256 sequences x 256 token ids (all ones); 256 equals the micro seq_len.
inp = torch.ones(256, 256, dtype=torch.int32).to(device)
gqa_decoder = gqa_decoder.to(device)
mha_decoder = mha_decoder.to(device)
mqa_decoder = mqa_decoder.to(device)
dma_decoder = dma_decoder.to(device)
gma_decoder = gma_decoder.to(device)

steps = 100   # timed forward passes per model
warmup = 100  # untimed forward passes per model before its timing starts


def _sync_device():
  """Block until queued CUDA work finishes.

  CUDA kernels launch asynchronously, so without a sync the wall-clock timer can stop
  before the GPU work it is meant to measure has actually completed.
  """
  if device.type == 'cuda':
    torch.cuda.synchronize()


# Models are timed in this order.
benchmark_models = {
  'DMA': dma_decoder,
  'GMA': gma_decoder,
  'GQA': gqa_decoder,
  'MHA': mha_decoder,
  'MQA': mqa_decoder,
}

# inference_mode: forward-only timing must not record autograd graphs (extra work + activation memory).
with torch.inference_mode(), torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
  print(inp.dtype)
  for name, model in benchmark_models.items():
    model.eval()  # disable dropout while timing
    # Every model gets its own warmup so one-time setup costs don't land in any timed loop.
    for _ in range(warmup):
      _ = model(inp)
    _sync_device()

    t1 = time.perf_counter()  # monotonic, high-resolution clock for interval timing
    for _ in range(steps):
      _ = model(inp)
    _sync_device()
    t2 = time.perf_counter()

    model.train()  # restore training mode (the default after construction) for the cells below
    print(f'{name}: ', t2 - t1, (t2 - t1) / steps)

cuda
torch.int32


KeyboardInterrupt: 

### Get tokenizer

In [None]:
# Load the tokenizer published with RxT-Alpha-Micro-Decoder and expose it as a Hugging Face
# PreTrainedTokenizerFast (see the displayed repr below).
# NOTE(review): the repr reports vocab_size=5000 but lists added special tokens at ids >= 5000
# (e.g. [SEP] -> 5000), while the micro models are built with vocab_size=5000; confirm the
# training dataset never emits those ids, otherwise embedding lookups would go out of range.
tr = TokenizerTrainer.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Decoder')
tokenizer = tr.get_hf_tokenizer()
tokenizer


PreTrainedTokenizerFast(name_or_path='', vocab_size=5000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("[EOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	4: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	5000: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	5001: AddedTok

### Load datasets

In [None]:
from datasets import load_dataset

# TinyStories ships separate train and validation splits, so no manual split is needed.
stories_dataset, hf_valid_dataset = (
    load_dataset('roneneldan/TinyStories', split=split_name, trust_remote_code=True)
    for split_name in ('train', 'validation')
)
len(stories_dataset), len(hf_valid_dataset)

(2119719, 21990)

### Train MultiHeadAttention model

In [None]:
# selected model: MHA
decoder = mha_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Re-seed right before this run (same seed as the micro config cell) so shuffling and dropout
# don't depend on which cells consumed RNG state earlier; every variant starts from the same
# RNG state. Assumes the trainer's data loader draws from torch's global RNG — TODO confirm.
torch.random.manual_seed(2137)
np.random.seed(2137)
random.seed(2137)

gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = len(stories_dataset)

steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# LR warmup spans the first 25% of an epoch.
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att/tensorboard_logs/mha'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

save_cb = ModelSaveCallback('./att/mha', push_to_hub=True,
                            hub_model_id='ReactiveAI/MHA-MAT', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='MHA Ref Transformer v2')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

# Wall-clock time of the whole trainer call (includes validation and checkpoint uploads).
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

cuda
Total steps per epoch: 8279
Start epoch: 0
Batch 100 / 8279 - loss: 7.71274995803833, last 100 batches mean loss: 8.3768
Total processed tokens: 5.57M
Batch 200 / 8279 - loss: 5.271077632904053, last 100 batches mean loss: 6.4205
Total processed tokens: 11.1M
Batch 300 / 8279 - loss: 4.4004316329956055, last 100 batches mean loss: 4.7176
Total processed tokens: 16.6M
Batch 400 / 8279 - loss: 3.9028382301330566, last 100 batches mean loss: 4.1434
Total processed tokens: 22.1M
Batch 500 / 8279 - loss: 3.55961012840271, last 100 batches mean loss: 3.6851
Total processed tokens: 27.6M
Batch 600 / 8279 - loss: 3.1952333450317383, last 100 batches mean loss: 3.3181
Total processed tokens: 33.1M
Batch 700 / 8279 - loss: 2.9454121589660645, last 100 batches mean loss: 3.0656
Total processed tokens: 38.6M
Batch 800 / 8279 - loss: 2.843296527862549, last 100 batches mean loss: 2.8905
Total processed tokens: 44.1M
Batch 900 / 8279 - loss: 2.6596810817718506, last 100 batches mean loss: 2.764

model.safetensors:   0%|          | 0.00/10.3M [00:00<?, ?B/s]

Epoch 0 - mean loss: 2.0982
Finished training! All losses:
[np.float64(2.098245295839033)]
Total training tokens: 456M
Final model saved to ./att/mha/final_model


model.safetensors:   0%|          | 0.00/10.3M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/MHA-MAT


'Total time: 29.48, Time per step/batch: 0.2136'

### Train GroupedQueryAttention model

In [None]:
# selected model: GQA
decoder = gqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Re-seed right before this run (same seed as the micro config cell) so shuffling and dropout
# don't depend on which cells consumed RNG state earlier; every variant starts from the same
# RNG state. Assumes the trainer's data loader draws from torch's global RNG — TODO confirm.
torch.random.manual_seed(2137)
np.random.seed(2137)
random.seed(2137)

gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = len(stories_dataset)

steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# LR warmup spans the first 25% of an epoch.
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att/tensorboard_logs/gqa'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

save_cb = ModelSaveCallback('./att/gqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/GQA-MAT', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GQA Ref Transformer v2')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

# Wall-clock time of the whole trainer call (includes validation and checkpoint uploads).
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

cuda
Total steps per epoch: 8279
Start epoch: 0
Batch 100 / 8279 - loss: 7.754459857940674, last 100 batches mean loss: 8.3664
Total processed tokens: 5.56M
Batch 200 / 8279 - loss: 5.195274353027344, last 100 batches mean loss: 6.4757
Total processed tokens: 11.1M
Batch 300 / 8279 - loss: 4.464555740356445, last 100 batches mean loss: 4.7448
Total processed tokens: 16.6M
Batch 400 / 8279 - loss: 3.9101648330688477, last 100 batches mean loss: 4.1402
Total processed tokens: 22.1M
Batch 500 / 8279 - loss: 3.4745724201202393, last 100 batches mean loss: 3.6851
Total processed tokens: 27.6M
Batch 600 / 8279 - loss: 3.1606638431549072, last 100 batches mean loss: 3.3237
Total processed tokens: 33.1M
Batch 700 / 8279 - loss: 2.9895577430725098, last 100 batches mean loss: 3.0948
Total processed tokens: 38.6M
Batch 800 / 8279 - loss: 2.8491697311401367, last 100 batches mean loss: 2.9393
Total processed tokens: 44.1M
Batch 900 / 8279 - loss: 2.6918606758117676, last 100 batches mean loss: 2.

model.safetensors:   0%|          | 0.00/9.70M [00:00<?, ?B/s]

Epoch 0 - mean loss: 2.1244
Finished training! All losses:
[np.float64(2.124399639500512)]
Total training tokens: 456M
Final model saved to ./att/gqa/final_model


model.safetensors:   0%|          | 0.00/9.70M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/GQA-MAT


'Total time: 28.79, Time per step/batch: 0.2086'

### Train MultiQueryAttention model

In [None]:
# selected model: MQA
decoder = mqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Re-seed right before this run (same seed as the micro config cell) so shuffling and dropout
# don't depend on which cells consumed RNG state earlier; every variant starts from the same
# RNG state. Assumes the trainer's data loader draws from torch's global RNG — TODO confirm.
torch.random.manual_seed(2137)
np.random.seed(2137)
random.seed(2137)

gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = len(stories_dataset)

steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# LR warmup spans the first 25% of an epoch.
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att/tensorboard_logs/mqa'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

save_cb = ModelSaveCallback('./att/mqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/MQA-MAT', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='MQA Ref Transformer v2')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

# Wall-clock time of the whole trainer call (includes validation and checkpoint uploads).
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

cuda
Total steps per epoch: 8279
Start epoch: 0
Batch 100 / 8279 - loss: 7.861738681793213, last 100 batches mean loss: 8.5077
Total processed tokens: 5.56M
Batch 200 / 8279 - loss: 5.239120006561279, last 100 batches mean loss: 6.5193
Total processed tokens: 11.1M
Batch 300 / 8279 - loss: 4.3553972244262695, last 100 batches mean loss: 4.7157
Total processed tokens: 16.6M
Batch 400 / 8279 - loss: 3.965428590774536, last 100 batches mean loss: 4.1305
Total processed tokens: 22.1M
Batch 500 / 8279 - loss: 3.476653575897217, last 100 batches mean loss: 3.6827
Total processed tokens: 27.6M
Batch 600 / 8279 - loss: 3.2227752208709717, last 100 batches mean loss: 3.3301
Total processed tokens: 33.1M
Batch 700 / 8279 - loss: 3.0403566360473633, last 100 batches mean loss: 3.0853
Total processed tokens: 38.6M
Batch 800 / 8279 - loss: 2.8099117279052734, last 100 batches mean loss: 2.9071
Total processed tokens: 44.1M
Batch 900 / 8279 - loss: 2.768620252609253, last 100 batches mean loss: 2.78

model.safetensors:   0%|          | 0.00/9.60M [00:00<?, ?B/s]

Epoch 0 - mean loss: 2.1500
Finished training! All losses:
[np.float64(2.1500039593033167)]
Total training tokens: 456M
Final model saved to ./att/mqa/final_model


model.safetensors:   0%|          | 0.00/9.60M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/MQA-MAT


'Total time: 28.92, Time per step/batch: 0.2096'

### Train GroupedMoeAttention model

In [None]:
# selected model: GMA
decoder = gma_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Re-seed right before this run (same seed as the micro config cell) so shuffling and dropout
# don't depend on which cells consumed RNG state earlier; every variant starts from the same
# RNG state. Assumes the trainer's data loader draws from torch's global RNG — TODO confirm.
torch.random.manual_seed(2137)
np.random.seed(2137)
random.seed(2137)

gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = len(stories_dataset)

steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# LR warmup spans the first 25% of an epoch.
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att/tensorboard_logs/gma'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

save_cb = ModelSaveCallback('./att/gma', push_to_hub=True,
                            hub_model_id='ReactiveAI/GMAT', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GroupedMoeAttentionTransformer v2')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# MoE attention variant: enable the trainer's MoE auxiliary loss, scaled by 0.01.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

# Wall-clock time of the whole trainer call (includes validation and checkpoint uploads).
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

cuda
Total steps per epoch: 8279
Start epoch: 0
Batch 100 / 8279 - loss: 8.029437065124512, last 100 batches mean loss: 8.6079
Total processed tokens: 5.56M
Batch 200 / 8279 - loss: 5.754378318786621, last 100 batches mean loss: 6.9400
Total processed tokens: 11.1M
Batch 300 / 8279 - loss: 4.479120254516602, last 100 batches mean loss: 4.9021
Total processed tokens: 16.6M
Batch 400 / 8279 - loss: 4.021343231201172, last 100 batches mean loss: 4.2052
Total processed tokens: 22.1M
Batch 500 / 8279 - loss: 3.5334341526031494, last 100 batches mean loss: 3.7333
Total processed tokens: 27.6M
Batch 600 / 8279 - loss: 3.2603158950805664, last 100 batches mean loss: 3.3810
Total processed tokens: 33.1M
Batch 700 / 8279 - loss: 3.021362066268921, last 100 batches mean loss: 3.1520
Total processed tokens: 38.6M
Batch 800 / 8279 - loss: 2.9401540756225586, last 100 batches mean loss: 2.9890
Total processed tokens: 44.1M
Batch 900 / 8279 - loss: 2.828660249710083, last 100 batches mean loss: 2.872

model.safetensors:   0%|          | 0.00/10.3M [00:00<?, ?B/s]

Epoch 0 - mean loss: 2.2083
Finished training! All losses:
[np.float64(2.208270316409028)]
Total training tokens: 456M
Final model saved to ./att/gma/final_model


model.safetensors:   0%|          | 0.00/10.3M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/GMAT


'Total time: 29.64, Time per step/batch: 0.2148'

### Train DeepMoeAttention model

In [None]:
# selected model: DMA
decoder = dma_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Re-seed right before this run (same seed as the micro config cell) so shuffling and dropout
# don't depend on which cells consumed RNG state earlier; every variant starts from the same
# RNG state. Assumes the trainer's data loader draws from torch's global RNG — TODO confirm.
torch.random.manual_seed(2137)
np.random.seed(2137)
random.seed(2137)

gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = len(stories_dataset)

steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# LR warmup spans the first 25% of an epoch.
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att/tensorboard_logs/dma'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

save_cb = ModelSaveCallback('./att/dma', push_to_hub=True,
                            hub_model_id='ReactiveAI/DMAT', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='DeepMoeAttentionTransformer v1')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# MoE attention variant: enable the trainer's MoE auxiliary loss, scaled by 0.01.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

# Wall-clock time of the whole trainer call (includes validation and checkpoint uploads).
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

cuda
Total steps per epoch: 8279
Start epoch: 0
Batch 100 / 8279 - loss: 7.906749248504639, last 100 batches mean loss: 8.5336
Total processed tokens: 5.56M
Batch 200 / 8279 - loss: 5.896146297454834, last 100 batches mean loss: 6.9306
Total processed tokens: 11.1M
Batch 300 / 8279 - loss: 4.4497528076171875, last 100 batches mean loss: 4.9165
Total processed tokens: 16.6M
Batch 400 / 8279 - loss: 3.9608097076416016, last 100 batches mean loss: 4.2008
Total processed tokens: 22.1M
Batch 500 / 8279 - loss: 3.6384663581848145, last 100 batches mean loss: 3.7688
Total processed tokens: 27.6M
Batch 600 / 8279 - loss: 3.3330605030059814, last 100 batches mean loss: 3.4389
Total processed tokens: 33.1M
Batch 700 / 8279 - loss: 3.169036388397217, last 100 batches mean loss: 3.2048
Total processed tokens: 38.6M
Batch 800 / 8279 - loss: 3.014430284500122, last 100 batches mean loss: 3.0451
Total processed tokens: 44.1M
Batch 900 / 8279 - loss: 2.9406826496124268, last 100 batches mean loss: 2.9

model.safetensors:   0%|          | 0.00/10.1M [00:00<?, ?B/s]

Epoch 0 - mean loss: 2.2479
Finished training! All losses:
[np.float64(2.2479153283264326)]
Total training tokens: 456M
Final model saved to ./att/dma/final_model


model.safetensors:   0%|          | 0.00/10.1M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/DMAT


'Total time: 29.86, Time per step/batch: 0.2164'

### Train Extended DMA model

In [None]:
# selected model: Extended DMA
# Seed before building the model: otherwise its initial weights depend on how much RNG state
# every earlier cell consumed, and this run cannot be reproduced on its own.
torch.random.manual_seed(2137)
np.random.seed(2137)
random.seed(2137)

# Micro-size backbone written out in full (reads vocab_size / embed_dim / seq_len from the micro
# config cell), with att_groups=8, att_num_experts=16, att_num_query_experts=16, att_num_query_groups=8.
ext_config = {
    'num_layers': 6,
    'vocab_size': vocab_size,
    'embed_dim': embed_dim,
    'att_heads': 8,
    'att_groups': 8,
    'seq_len': seq_len,
    'use_flash_attention': False,
    'use_gated': True,
    'ff_dropout': 0.1,
    'ff_activation': 'silu',
    'ff_dim': 384,
    'att_type': 'dma',
    'att_num_experts': 16,
    'att_num_query_experts': 16,
    'att_num_query_groups': 8
}
decoder = ExperimentalAttentionTransformer(**ext_config)
print(decoder.params_count())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Re-seed after model construction so training starts from the same RNG state as the other
# micro runs (assumes the trainer's data loader draws from torch's global RNG — TODO confirm).
torch.random.manual_seed(2137)
np.random.seed(2137)
random.seed(2137)

gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 1
gradient_acc_steps = 1

# Peak LR scales linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = len(stories_dataset)

steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# LR warmup spans the first 25% of an epoch.
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att/tensorboard_logs/xdma'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

save_cb = ModelSaveCallback('./att/xdma', push_to_hub=True,
                            hub_model_id='ReactiveAI/xDMAT', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='ExtendedDeepMoeAttentionTransformer v1')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# MoE attention variant: enable the trainer's MoE auxiliary loss, scaled by 0.01.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

# Wall-clock time of the whole trainer call (includes validation and checkpoint uploads).
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

Model params 2.89M
cuda
Total steps per epoch: 8279
Start epoch: 0
Batch 100 / 8279 - loss: 7.888348579406738, last 100 batches mean loss: 8.3611
Total processed tokens: 5.56M
Batch 200 / 8279 - loss: 6.376858711242676, last 100 batches mean loss: 7.0265
Total processed tokens: 11.1M
Batch 300 / 8279 - loss: 4.502211093902588, last 100 batches mean loss: 5.1582
Total processed tokens: 16.6M
Batch 400 / 8279 - loss: 4.014655590057373, last 100 batches mean loss: 4.1986
Total processed tokens: 22.1M
Batch 500 / 8279 - loss: 3.5566840171813965, last 100 batches mean loss: 3.7417
Total processed tokens: 27.6M
Batch 600 / 8279 - loss: 3.28279972076416, last 100 batches mean loss: 3.4235
Total processed tokens: 33.1M
Batch 700 / 8279 - loss: 3.1950063705444336, last 100 batches mean loss: 3.2181
Total processed tokens: 38.6M
Batch 800 / 8279 - loss: 3.0198488235473633, last 100 batches mean loss: 3.0811
Total processed tokens: 44.1M
Batch 900 / 8279 - loss: 2.8880774974823, last 100 batches 

model.safetensors:   0%|          | 0.00/11.6M [00:00<?, ?B/s]

Epoch 0 - mean loss: 2.2768
Finished training! All losses:
[np.float64(2.2767607357766892)]
Total training tokens: 456M
Final model saved to ./att/xdma/final_model


model.safetensors:   0%|          | 0.00/11.6M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/xDMAT


'Total time: 32.07, Time per step/batch: 0.2324'

## Mixture-of-Experts Attention - Test on mini-size models (~11-12M params) on the Wikipedia dataset

> For faster results, only 50% of the Wikipedia dataset (>3M items) is used: 45% for training and 5% for validation.

> Unfortunately, Mixture-of-Experts Attention gave worse results than classic GQA/MQA, so we abandoned this study and moved on to **Sparse Query Attention** research.

> **MHA/GQA/MQA** results from this run were later used as baselines for comparison with the **SQA** variants.

In [None]:
# Mini-size setup (~11-12M params per model): 256-dim embeddings, 10k vocab, 1024-token context.
embed_dim = 256
vocab_size = 10_000
seq_len = 1024

MINI_SEED = 210037
torch.random.manual_seed(MINI_SEED)
np.random.seed(MINI_SEED)
random.seed(MINI_SEED)

# Shared backbone hyper-parameters; the variants below differ only in attention settings.
base_config = dict(
    num_layers=8,
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    att_heads=16,
    att_groups=4,
    seq_len=seq_len,
    use_flash_attention=False,
    use_gated=True,
    ff_dropout=0.1,
    ff_activation='silu',
    ff_dim=768,
)

gqa_config = dict(base_config, att_type='gqa')
mqa_config = dict(base_config, att_type='mqa')
mha_config = dict(base_config, att_type='mha')
gma_config = dict(base_config, att_type='gma')
dma_config = dict(base_config, att_type='dma', att_num_query_groups=8)

# Keep this construction order: each model draws its initial weights from the seeded RNG in turn.
gqa_decoder = ExperimentalAttentionTransformer(**gqa_config)
mqa_decoder = ExperimentalAttentionTransformer(**mqa_config)
mha_decoder = ExperimentalAttentionTransformer(**mha_config)
gma_decoder = ExperimentalAttentionTransformer(**gma_config)
dma_decoder = ExperimentalAttentionTransformer(**dma_config)

tuple(
    (label, model.params_count())
    for label, model in (
        ('GQA', gqa_decoder),
        ('MQA', mqa_decoder),
        ('MHA', mha_decoder),
        ('GMA', gma_decoder),
        ('DMA', dma_decoder),
    )
)

(('GQA', 'Model params 11.2M'),
 ('MQA', 'Model params 11M'),
 ('MHA', 'Model params 12M'),
 ('GMA', 'Model params 12M'),
 ('DMA', 'Model params 11.8M'))

#### Use the pretrained tokenizer from RxT-Alpha Mini

In [None]:
# Load the tokenizer published with RxT-Alpha-Mini-Decoder for the mini models (built with vocab_size = 10_000).
# NOTE(review): confirm the tokenizer's ids (including added special tokens) stay below vocab_size.
tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Alpha-Mini-Decoder')
tokenizer

tokenizer.json:   0%|          | 0.00/670k [00:00<?, ?B/s]

### Load and split dataset

In [None]:
from datasets import load_dataset

# English Wikipedia (config '20231101.en'), loaded as a single 'train' split and sliced below.
wiki_dataset = load_dataset('wikimedia/wikipedia', '20231101.en', split='train', trust_remote_code=True)
train_split_len = int(len(wiki_dataset) * 0.45)
valid_split_len = int(len(wiki_dataset) * 0.05)
# Contiguous, non-shuffled slices: the first 45% for training, the next 5% for validation,
# and everything after that kept aside for the later continuation ("next round") training.
wiki_train_dataset = wiki_dataset.select(range(train_split_len))
wiki_valid_dataset = wiki_dataset.select(range(train_split_len, train_split_len + valid_split_len))
_rest_dataset = wiki_dataset.select(range(train_split_len + valid_split_len, len(wiki_dataset)))
# Sanity check: the three slices must cover the whole dataset with no gap or overlap.
assert len(wiki_train_dataset) + len(wiki_valid_dataset) + len(_rest_dataset) == len(wiki_dataset), \
    'train/valid/rest splits do not partition the dataset'
len(wiki_train_dataset), len(wiki_valid_dataset), len(_rest_dataset)

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

(2883516, 320390, 3203908)

### Train GQA Mini model




In [None]:
# selected model: GQA — trained from scratch on the Wikipedia train split
decoder = gqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by previous runs before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate scales linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# NOTE(review): one batch is subtracted from the per-epoch count — presumably to avoid
# counting a partial final batch; confirm against AutoregressiveTrainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Number of optimizer updates for the whole run (one per accumulated group of batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# Warm up the learning rate over the first 25% of the epoch's batches.
warmup_steps = int(0.25 * steps_per_epoch)

# TensorBoard log directory for this variant; exist_ok avoids a check-then-create race.
logs_dir = './att_mini/tensorboard_logs/gqa'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8_000_000_000 is presumably a token budget/cap for the counter — confirm.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves the model locally and pushes checkpoints plus final weights to a private hub repo.
save_cb = ModelSaveCallback('./att_mini/gqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/GQA-MAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GQA Mini Ref Transformer v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# Non-MoE attention variant: the MoE auxiliary loss is disabled.
trainer = AutoregressiveTrainer(
    decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
    vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
    dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps,
    use_moe_aux_loss=False, moe_aux_loss_scale=0.01,
)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock timing of the whole training run, in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Displayed: total time in minutes, then seconds per step/batch.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

cuda
Model params 11.2M
Total steps per epoch: 22526
Start epoch: 0
Batch 100 / 22526 - loss: 5.277674198150635, last 100 batches mean loss: 7.9230
Total processed tokens: 6.66M
Batch 200 / 22526 - loss: 4.061792373657227, last 100 batches mean loss: 4.5532
Total processed tokens: 13.3M
Batch 300 / 22526 - loss: 3.6583282947540283, last 100 batches mean loss: 3.7906
Total processed tokens: 19.9M
Batch 400 / 22526 - loss: 3.256650686264038, last 100 batches mean loss: 3.5088
Total processed tokens: 26.5M
Batch 500 / 22526 - loss: 3.601026773452759, last 100 batches mean loss: 3.3972
Total processed tokens: 33.1M
Batch 600 / 22526 - loss: 3.207859992980957, last 100 batches mean loss: 3.2777
Total processed tokens: 39.7M
Batch 700 / 22526 - loss: 3.0508265495300293, last 100 batches mean loss: 3.1647
Total processed tokens: 46.3M
Batch 800 / 22526 - loss: 2.9742844104766846, last 100 batches mean loss: 3.0276
Total processed tokens: 53M
Batch 900 / 22526 - loss: 3.0777719020843506, last 

model.safetensors:   0%|          | 0.00/44.7M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.8064
Finished training! All losses:
[np.float64(1.8063621737009021)]
Total training tokens: 1.49B
Final model saved to ./att_mini/gqa/final_model


model.safetensors:   0%|          | 0.00/44.7M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/GQA-MAT-m


'Total time: 258.19, Time per step/batch: 0.6877'

### Train MHA Mini model

In [None]:
# selected model: MHA — trained from scratch on the Wikipedia train split
decoder = mha_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by previous runs before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate scales linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# NOTE(review): one batch is subtracted from the per-epoch count — presumably to avoid
# counting a partial final batch; confirm against AutoregressiveTrainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Number of optimizer updates for the whole run (one per accumulated group of batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# Warm up the learning rate over the first 25% of the epoch's batches.
warmup_steps = int(0.25 * steps_per_epoch)

# TensorBoard log directory for this variant; exist_ok avoids a check-then-create race.
logs_dir = './att_mini/tensorboard_logs/mha'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8_000_000_000 is presumably a token budget/cap for the counter — confirm.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves the model locally and pushes checkpoints plus final weights to a private hub repo.
save_cb = ModelSaveCallback('./att_mini/mha', push_to_hub=True,
                            hub_model_id='ReactiveAI/MHA-MAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='MHA Mini Ref Transformer v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# Non-MoE attention variant: the MoE auxiliary loss is disabled.
trainer = AutoregressiveTrainer(
    decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
    vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
    dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps,
    use_moe_aux_loss=False, moe_aux_loss_scale=0.01,
)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock timing of the whole training run, in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Displayed: total time in minutes, then seconds per step/batch.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

cuda
Model params 12M
Total steps per epoch: 22526
Start epoch: 0
Batch 100 / 22526 - loss: 5.140275478363037, last 100 batches mean loss: 8.2050
Total processed tokens: 6.67M
Batch 200 / 22526 - loss: 3.771393299102783, last 100 batches mean loss: 4.4944
Total processed tokens: 13.2M
Batch 300 / 22526 - loss: 3.2529797554016113, last 100 batches mean loss: 3.7546
Total processed tokens: 19.9M
Batch 400 / 22526 - loss: 3.5089826583862305, last 100 batches mean loss: 3.4915
Total processed tokens: 26.4M
Batch 500 / 22526 - loss: 3.30502986907959, last 100 batches mean loss: 3.4272
Total processed tokens: 33M
Batch 600 / 22526 - loss: 3.229665517807007, last 100 batches mean loss: 3.2912
Total processed tokens: 39.7M
Batch 700 / 22526 - loss: 3.004945993423462, last 100 batches mean loss: 3.1047
Total processed tokens: 46.2M
Batch 800 / 22526 - loss: 2.7601706981658936, last 100 batches mean loss: 2.9983
Total processed tokens: 52.9M
Batch 900 / 22526 - loss: 3.108818531036377, last 100 

model.safetensors:   0%|          | 0.00/47.9M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.7771
Finished training! All losses:
[np.float64(1.777103994463871)]
Total training tokens: 1.49B
Final model saved to ./att_mini/mha/final_model


model.safetensors:   0%|          | 0.00/47.9M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/MHA-MAT-m


'Total time: 269.32, Time per step/batch: 0.7173'

### Train MQA Mini model

In [None]:
# selected model: MQA — trained from scratch on the Wikipedia train split
decoder = mqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by previous runs before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate scales linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# NOTE(review): one batch is subtracted from the per-epoch count — presumably to avoid
# counting a partial final batch; confirm against AutoregressiveTrainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Number of optimizer updates for the whole run (one per accumulated group of batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# Warm up the learning rate over the first 25% of the epoch's batches.
warmup_steps = int(0.25 * steps_per_epoch)

# TensorBoard log directory for this variant; exist_ok avoids a check-then-create race.
logs_dir = './att_mini/tensorboard_logs/mqa'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8_000_000_000 is presumably a token budget/cap for the counter — confirm.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves the model locally and pushes checkpoints plus final weights to a private hub repo.
save_cb = ModelSaveCallback('./att_mini/mqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/MQA-MAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='MQA Mini Ref Transformer v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# Non-MoE attention variant: the MoE auxiliary loss is disabled.
trainer = AutoregressiveTrainer(
    decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
    vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
    dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps,
    use_moe_aux_loss=False, moe_aux_loss_scale=0.01,
)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock timing of the whole training run, in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Displayed: total time in minutes, then seconds per step/batch.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

cuda
Model params 11M
Total steps per epoch: 22526
Start epoch: 0
Batch 100 / 22526 - loss: 5.595656871795654, last 100 batches mean loss: 8.0259
Total processed tokens: 6.73M
Batch 200 / 22526 - loss: 4.544164657592773, last 100 batches mean loss: 4.5177
Total processed tokens: 13.3M
Batch 300 / 22526 - loss: 3.4468321800231934, last 100 batches mean loss: 3.7101
Total processed tokens: 19.9M
Batch 400 / 22526 - loss: 3.1616461277008057, last 100 batches mean loss: 3.5338
Total processed tokens: 26.5M
Batch 500 / 22526 - loss: 3.1317455768585205, last 100 batches mean loss: 3.3983
Total processed tokens: 33.1M
Batch 600 / 22526 - loss: 3.0365724563598633, last 100 batches mean loss: 3.3287
Total processed tokens: 39.8M
Batch 700 / 22526 - loss: 3.1562113761901855, last 100 batches mean loss: 3.1331
Total processed tokens: 46.4M
Batch 800 / 22526 - loss: 2.843433141708374, last 100 batches mean loss: 3.0090
Total processed tokens: 53M
Batch 900 / 22526 - loss: 2.657249689102173, last 1

model.safetensors:   0%|          | 0.00/43.9M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.8404
Finished training! All losses:
[np.float64(1.8403589271031129)]
Total training tokens: 1.49B
Final model saved to ./att_mini/mqa/final_model


model.safetensors:   0%|          | 0.00/43.9M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/MQA-MAT-m


'Total time: 260.83, Time per step/batch: 0.6947'

### Train GMA Mini model

In [None]:
# selected model: GMA — trained from scratch on the Wikipedia train split
decoder = gma_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by previous runs before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate scales linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# NOTE(review): one batch is subtracted from the per-epoch count — presumably to avoid
# counting a partial final batch; confirm against AutoregressiveTrainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Number of optimizer updates for the whole run (one per accumulated group of batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# Warm up the learning rate over the first 25% of the epoch's batches.
warmup_steps = int(0.25 * steps_per_epoch)

# TensorBoard log directory for this variant; exist_ok avoids a check-then-create race.
logs_dir = './att_mini/tensorboard_logs/gma'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8_000_000_000 is presumably a token budget/cap for the counter — confirm.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves the model locally and pushes checkpoints plus final weights to a private hub repo.
save_cb = ModelSaveCallback('./att_mini/gma', push_to_hub=True,
                            hub_model_id='ReactiveAI/GMAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GroupedMoeAttentionTransformer Mini v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# MoE attention variant: the MoE auxiliary loss is enabled, scaled by 0.01
# (presumably a router load-balancing term — confirm in rxnn).
trainer = AutoregressiveTrainer(
    decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
    vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
    dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps,
    use_moe_aux_loss=True, moe_aux_loss_scale=0.01,
)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock timing of the whole training run, in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Displayed: total time in minutes, then seconds per step/batch.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

cuda
Model params 12M
Total steps per epoch: 22526
Start epoch: 0
Batch 100 / 22526 - loss: 6.356077194213867, last 100 batches mean loss: 8.3363
Total processed tokens: 6.66M
Batch 200 / 22526 - loss: 4.505393981933594, last 100 batches mean loss: 4.8644
Total processed tokens: 13.3M
Batch 300 / 22526 - loss: 3.7619831562042236, last 100 batches mean loss: 4.2513
Total processed tokens: 19.9M
Batch 400 / 22526 - loss: 3.2583963871002197, last 100 batches mean loss: 3.5277
Total processed tokens: 26.5M
Batch 500 / 22526 - loss: 3.597238540649414, last 100 batches mean loss: 3.3934
Total processed tokens: 33.1M
Batch 600 / 22526 - loss: 3.204205274581909, last 100 batches mean loss: 3.2725
Total processed tokens: 39.7M
Batch 700 / 22526 - loss: 3.045901298522949, last 100 batches mean loss: 3.1607
Total processed tokens: 46.3M
Batch 800 / 22526 - loss: 2.983999490737915, last 100 batches mean loss: 3.0331
Total processed tokens: 53M
Batch 900 / 22526 - loss: 3.089254379272461, last 100 

model.safetensors:   0%|          | 0.00/48.0M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.9087
Finished training! All losses:
[np.float64(1.9087171981299442)]
Total training tokens: 1.49B
Final model saved to ./att_mini/gma/final_model


model.safetensors:   0%|          | 0.00/48.0M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/GMAT-m


'Total time: 272.66, Time per step/batch: 0.7262'

### Train DMA Mini model

In [None]:
# selected model: DMA — trained from scratch on the Wikipedia train split
decoder = dma_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by previous runs before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate scales linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# NOTE(review): one batch is subtracted from the per-epoch count — presumably to avoid
# counting a partial final batch; confirm against AutoregressiveTrainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Number of optimizer updates for the whole run (one per accumulated group of batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# Warm up the learning rate over the first 25% of the epoch's batches.
warmup_steps = int(0.25 * steps_per_epoch)

# TensorBoard log directory for this variant; exist_ok avoids a check-then-create race.
logs_dir = './att_mini/tensorboard_logs/dma'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8_000_000_000 is presumably a token budget/cap for the counter — confirm.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves the model locally and pushes checkpoints plus final weights to a private hub repo.
save_cb = ModelSaveCallback('./att_mini/dma', push_to_hub=True,
                            hub_model_id='ReactiveAI/DMAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='DeepMoeAttentionTransformer Mini v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# MoE attention variant: the MoE auxiliary loss is enabled, scaled by 0.01
# (presumably a router load-balancing term — confirm in rxnn).
trainer = AutoregressiveTrainer(
    decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
    vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
    dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps,
    use_moe_aux_loss=True, moe_aux_loss_scale=0.01,
)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock timing of the whole training run, in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Displayed: total time in minutes, then seconds per step/batch.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

cuda
Model params 11.8M
Total steps per epoch: 22526
Start epoch: 0
Batch 100 / 22526 - loss: 6.498315811157227, last 100 batches mean loss: 8.4733
Total processed tokens: 6.66M
Batch 200 / 22526 - loss: 4.497570037841797, last 100 batches mean loss: 4.9042
Total processed tokens: 13.3M
Batch 300 / 22526 - loss: 3.783238649368286, last 100 batches mean loss: 4.2527
Total processed tokens: 19.9M
Batch 400 / 22526 - loss: 3.2575321197509766, last 100 batches mean loss: 3.5387
Total processed tokens: 26.5M
Batch 500 / 22526 - loss: 3.601799249649048, last 100 batches mean loss: 3.3987
Total processed tokens: 33.1M
Batch 600 / 22526 - loss: 3.2021563053131104, last 100 batches mean loss: 3.2755
Total processed tokens: 39.7M
Batch 700 / 22526 - loss: 3.0539073944091797, last 100 batches mean loss: 3.1659
Total processed tokens: 46.3M
Batch 800 / 22526 - loss: 2.9873478412628174, last 100 batches mean loss: 3.0360
Total processed tokens: 53M
Batch 900 / 22526 - loss: 3.0921452045440674, last

model.safetensors:   0%|          | 0.00/47.1M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.9761
Finished training! All losses:
[np.float64(1.9761330786158602)]
Total training tokens: 1.49B
Final model saved to ./att_mini/dma/final_model


model.safetensors:   0%|          | 0.00/47.1M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/DMAT-m


'Total time: 266.88, Time per step/batch: 0.7109'

## Mixture-of-Experts Attention - Next round

> An additional epoch to finally confirm that **MoE Attention** is not worth using.

> Unfortunately, Mixture-of-Experts Attention gave worse results than classic GQA/MQA, so we abandoned this study and proceeded to **Sparse Query Attention** research.

### GQA

In [None]:
# Continuation runs train on the part of Wikipedia left over after the train/valid slices.
next_dataset = _rest_dataset

In [None]:
# selected model: GQA — continue training the checkpoint pushed by the first GQA run
# NOTE(review): the models were built with ExperimentalAttentionTransformer but are reloaded via
# MoeAttentionTransformer here — presumably a compatible/alias class in this rxnn version; confirm.
decoder = MoeAttentionTransformer.from_pretrained('ReactiveAI/GQA-MAT-m')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by previous runs before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Lower peak LR than the from-scratch runs (5e-4), presumably because training resumes
# from already-trained weights.
peak_lr = 3e-4 * gradient_acc_steps

subset_len = int(len(next_dataset))

# NOTE(review): one batch is subtracted from the per-epoch count — presumably to avoid
# counting a partial final batch; confirm against AutoregressiveTrainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Number of optimizer updates for the whole run (one per accumulated group of batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# No warmup: training resumes from a pretrained checkpoint.
warmup_steps = 0

# TensorBoard log directory for this run; exist_ok avoids a check-then-create race.
logs_dir = './att_mini/tensorboard_logs/gqa_rest'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8_000_000_000 is presumably a token budget/cap for the counter — confirm.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Pushes to the same hub repo as the first GQA run, with a 'v2' commit message.
save_cb = ModelSaveCallback('./att_mini/gqa_rest', push_to_hub=True,
                            hub_model_id='ReactiveAI/GQA-MAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GQA Mini Ref Transformer v2')

train_dataset = AutoregressiveLMDataset(next_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# Non-MoE attention variant: the MoE auxiliary loss is disabled.
trainer = AutoregressiveTrainer(
    decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
    vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
    dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps,
    use_moe_aux_loss=False, moe_aux_loss_scale=0.01,
)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock timing of the whole training run, in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Displayed: total time in minutes, then seconds per step/batch.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/44.7M [00:00<?, ?B/s]

cuda
Model params 11.2M
Total steps per epoch: 25029
Start epoch: 0
Batch 100 / 25029 - loss: 1.6770919561386108, last 100 batches mean loss: 1.5599
Total processed tokens: 6.62M
Batch 200 / 25029 - loss: 1.776141881942749, last 100 batches mean loss: 1.5468
Total processed tokens: 13.1M
Batch 300 / 25029 - loss: 1.5352041721343994, last 100 batches mean loss: 1.5562
Total processed tokens: 19.7M
Batch 400 / 25029 - loss: 1.6222599744796753, last 100 batches mean loss: 1.5246
Total processed tokens: 26.1M
Batch 500 / 25029 - loss: 1.3867709636688232, last 100 batches mean loss: 1.5391
Total processed tokens: 32.5M
Batch 600 / 25029 - loss: 1.4912333488464355, last 100 batches mean loss: 1.5346
Total processed tokens: 39M
Batch 700 / 25029 - loss: 1.3765363693237305, last 100 batches mean loss: 1.5369
Total processed tokens: 45.5M
Batch 800 / 25029 - loss: 1.4580144882202148, last 100 batches mean loss: 1.5448
Total processed tokens: 52M
Batch 900 / 25029 - loss: 1.4422554969787598, las

model.safetensors:   0%|          | 0.00/44.7M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.4884
Finished training! All losses:
[np.float64(1.4883932043064894)]
Total training tokens: 1.63B
Final model saved to ./att_mini/gqa_rest/final_model


model.safetensors:   0%|          | 0.00/44.7M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/GQA-MAT-m


'Total time: 288.15, Time per step/batch: 0.6908'

### GMA

In [None]:
# selected model: GMA — continue training the checkpoint pushed by the first GMA run
decoder = MoeAttentionTransformer.from_pretrained('ReactiveAI/GMAT-m')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by previous runs before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Lower peak LR than the from-scratch runs (5e-4), presumably because training resumes
# from already-trained weights.
peak_lr = 3e-4 * gradient_acc_steps

subset_len = int(len(next_dataset))

# NOTE(review): one batch is subtracted from the per-epoch count — presumably to avoid
# counting a partial final batch; confirm against AutoregressiveTrainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Number of optimizer updates for the whole run (one per accumulated group of batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# No warmup: training resumes from a pretrained checkpoint.
warmup_steps = 0

# TensorBoard log directory for this run; exist_ok avoids a check-then-create race.
logs_dir = './att_mini/tensorboard_logs/gma_rest'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8_000_000_000 is presumably a token budget/cap for the counter — confirm.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Pushes to the same hub repo as the first GMA run, with a 'v2' commit message.
save_cb = ModelSaveCallback('./att_mini/gma_rest', push_to_hub=True,
                            hub_model_id='ReactiveAI/GMAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GroupedMoeAttentionTransformer Mini v2')

train_dataset = AutoregressiveLMDataset(next_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# MoE attention variant: the MoE auxiliary loss is enabled, scaled by 0.01
# (presumably a router load-balancing term — confirm in rxnn).
trainer = AutoregressiveTrainer(
    decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
    vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
    dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps,
    use_moe_aux_loss=True, moe_aux_loss_scale=0.01,
)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock timing of the whole training run, in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Displayed: total time in minutes, then seconds per step/batch.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/48.0M [00:00<?, ?B/s]

cuda
Model params 12M
Total steps per epoch: 25029
Start epoch: 0
Batch 100 / 25029 - loss: 1.5934804677963257, last 100 batches mean loss: 1.6427
Total processed tokens: 6.56M
Batch 200 / 25029 - loss: 1.7523225545883179, last 100 batches mean loss: 1.6270
Total processed tokens: 13M
Batch 300 / 25029 - loss: 1.646765947341919, last 100 batches mean loss: 1.6417
Total processed tokens: 19.5M
Batch 400 / 25029 - loss: 1.835755705833435, last 100 batches mean loss: 1.6486
Total processed tokens: 26M
Batch 500 / 25029 - loss: 1.5902632474899292, last 100 batches mean loss: 1.6372
Total processed tokens: 32.5M
Batch 600 / 25029 - loss: 1.8100734949111938, last 100 batches mean loss: 1.6591
Total processed tokens: 39.1M
Batch 700 / 25029 - loss: 1.7010889053344727, last 100 batches mean loss: 1.6280
Total processed tokens: 45.6M
Batch 800 / 25029 - loss: 1.5895072221755981, last 100 batches mean loss: 1.6365
Total processed tokens: 52.1M
Batch 900 / 25029 - loss: 1.548598289489746, last 10

model.safetensors:   0%|          | 0.00/48.0M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.5677
Finished training! All losses:
[np.float64(1.5677252254674685)]
Total training tokens: 1.63B
Final model saved to ./att_mini/gma_rest/final_model


model.safetensors:   0%|          | 0.00/48.0M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/GMAT-m


'Total time: 304.15, Time per step/batch: 0.7291'

### DMA

In [None]:
# selected model: DMA — continue training the checkpoint pushed by the first DMA run
decoder = MoeAttentionTransformer.from_pretrained('ReactiveAI/DMAT-m')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by previous runs before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Lower peak LR than the from-scratch runs (5e-4), presumably because training resumes
# from already-trained weights.
peak_lr = 3e-4 * gradient_acc_steps

subset_len = int(len(next_dataset))

# NOTE(review): one batch is subtracted from the per-epoch count — presumably to avoid
# counting a partial final batch; confirm against AutoregressiveTrainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Number of optimizer updates for the whole run (one per accumulated group of batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# No warmup: training resumes from a pretrained checkpoint.
warmup_steps = 0

# TensorBoard log directory for this run; exist_ok avoids a check-then-create race.
logs_dir = './att_mini/tensorboard_logs/dma_rest'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8_000_000_000 is presumably a token budget/cap for the counter — confirm.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Pushes to the same hub repo as the first DMA run, with a 'v2' commit message.
save_cb = ModelSaveCallback('./att_mini/dma_rest', push_to_hub=True,
                            hub_model_id='ReactiveAI/DMAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='DeepMoeAttentionTransformer Mini v2')

train_dataset = AutoregressiveLMDataset(next_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# MoE attention variant: the MoE auxiliary loss is enabled, scaled by 0.01
# (presumably a router load-balancing term — confirm in rxnn).
trainer = AutoregressiveTrainer(
    decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
    vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
    dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps,
    use_moe_aux_loss=True, moe_aux_loss_scale=0.01,
)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock timing of the whole training run, in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Displayed: total time in minutes, then seconds per step/batch.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

config.json:   0%|          | 0.00/462 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/47.1M [00:00<?, ?B/s]

cuda
Model params 11.8M
Total steps per epoch: 25029
Start epoch: 0
Batch 100 / 25029 - loss: 1.8953702449798584, last 100 batches mean loss: 1.7435
Total processed tokens: 6.61M
Batch 200 / 25029 - loss: 1.6441082954406738, last 100 batches mean loss: 1.7089
Total processed tokens: 13.1M
Batch 300 / 25029 - loss: 1.4566195011138916, last 100 batches mean loss: 1.7250
Total processed tokens: 19.6M
Batch 400 / 25029 - loss: 1.7590229511260986, last 100 batches mean loss: 1.7213
Total processed tokens: 26.1M
Batch 500 / 25029 - loss: 1.7074635028839111, last 100 batches mean loss: 1.6907
Total processed tokens: 32.5M
Batch 600 / 25029 - loss: 1.5931276082992554, last 100 batches mean loss: 1.7095
Total processed tokens: 39M
Batch 700 / 25029 - loss: 2.0504255294799805, last 100 batches mean loss: 1.7316
Total processed tokens: 45.5M
Batch 800 / 25029 - loss: 1.6790380477905273, last 100 batches mean loss: 1.6960
Total processed tokens: 51.9M
Batch 900 / 25029 - loss: 1.6751078367233276, 

model.safetensors:   0%|          | 0.00/47.1M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.6490
Finished training! All losses:
[np.float64(1.6490000492041463)]
Total training tokens: 1.63B
Final model saved to ./att_mini/dma_rest/final_model


model.safetensors:   0%|          | 0.00/47.1M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/DMAT-m


'Total time: 300.32, Time per step/batch: 0.7199'

## SparseQueryAttention - Mini models test

> Tests of **SQA** variants, with all other params the same as in the Mini models above.

### Sparse Query Attention base

In [None]:
# SQA variant: the shared base hyper-parameters plus 8 query groups.
sqa_config = {**base_config, 'att_type': 'sqa', 'att_num_query_groups': 8}

sqa_decoder = ExperimentalAttentionTransformer(**sqa_config)
sqa_decoder.params_count()

'Model params 10.7M'

In [None]:
# selected model: SQA — trained from scratch on the Wikipedia train split
decoder = sqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by previous runs before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate scales linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = int(len(wiki_train_dataset))

# NOTE(review): one batch is subtracted from the per-epoch count — presumably to avoid
# counting a partial final batch; confirm against AutoregressiveTrainer's data loader.
steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Number of optimizer updates for the whole run (one per accumulated group of batches).
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# Warm up the learning rate over the first 25% of the epoch's batches.
warmup_steps = int(0.25 * steps_per_epoch)

# TensorBoard log directory for this variant; exist_ok avoids a check-then-create race.
logs_dir = './att_mini/tensorboard_logs/sqa'
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8_000_000_000 is presumably a token budget/cap for the counter — confirm.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves the model locally and pushes checkpoints plus final weights to a private hub repo.
# NOTE(review): labelled 'v2' although this is the first SQA run in this notebook — confirm intended.
save_cb = ModelSaveCallback('./att_mini/sqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/SQAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='SQA Mini Ref Transformer v2')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# Non-MoE attention variant: the MoE auxiliary loss is disabled.
trainer = AutoregressiveTrainer(
    decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
    vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
    dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps,
    use_moe_aux_loss=False, moe_aux_loss_scale=0.01,
)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Wall-clock timing of the whole training run, in seconds.
t1 = datetime.timestamp(datetime.now())
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = datetime.timestamp(datetime.now())

training_time = t2 - t1
time_per_step = training_time / total_steps

# Displayed: total time in minutes, then seconds per step/batch.
f'Total time: {(training_time / 60):.2f}, Time per step/batch: {time_per_step:.4f}'

cuda
Model params 10.7M
Total steps per epoch: 22526
Start epoch: 0
Batch 100 / 22526 - loss: 5.810866832733154, last 100 batches mean loss: 8.5427
Total processed tokens: 6.69M
Batch 200 / 22526 - loss: 3.5995867252349854, last 100 batches mean loss: 4.7055
Total processed tokens: 13.3M
Batch 300 / 22526 - loss: 3.6262636184692383, last 100 batches mean loss: 3.8414
Total processed tokens: 19.9M
Batch 400 / 22526 - loss: 3.743838310241699, last 100 batches mean loss: 3.5306
Total processed tokens: 26.5M
Batch 500 / 22526 - loss: 3.3322360515594482, last 100 batches mean loss: 3.4150
Total processed tokens: 33.2M
Batch 600 / 22526 - loss: 3.1847569942474365, last 100 batches mean loss: 3.2974
Total processed tokens: 39.8M
Batch 700 / 22526 - loss: 2.9932663440704346, last 100 batches mean loss: 3.1145
Total processed tokens: 46.4M
Batch 800 / 22526 - loss: 2.9160711765289307, last 100 batches mean loss: 3.0289
Total processed tokens: 53M
Batch 900 / 22526 - loss: 2.8893659114837646, la

model.safetensors:   0%|          | 0.00/42.6M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.8212
Finished training! All losses:
[np.float64(1.8212270385527805)]
Total training tokens: 1.49B
Final model saved to ./att_mini/sqa/final_model


model.safetensors:   0%|          | 0.00/42.6M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/SQAT-m


'Total time: 240.93, Time per step/batch: 0.6417'

### SparseQueryAttention - extreme version

In [None]:
# Extreme SQA variant: attention type 'sqa' with only 4 query groups; everything else from base_config.
xsqa_config = dict(base_config, att_type='sqa', att_num_query_groups=4)

xsqa_decoder = ExperimentalAttentionTransformer(**xsqa_config)
xsqa_decoder.params_count()

'Model params 10.4M'

In [None]:
# selected model: xSQA
import time

decoder = xsqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by the previous run before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate is scaled linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = len(wiki_train_dataset)

steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Optimizer steps over the whole run; LR warmup spans the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att_mini/tensorboard_logs/xsqa'
# exist_ok replaces the separate exists() check (and its check-then-create race).
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8B looks like a token cap set far above this run's ~1.5B tokens - confirm TokenCounterCallback semantics.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

save_cb = ModelSaveCallback('./att_mini/xsqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/xSQAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='xSQA Mini Transformer v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# use_moe_aux_loss=False, so moe_aux_loss_scale presumably has no effect for this run.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# perf_counter is monotonic, unlike wall-clock datetime.now(), so system clock
# adjustments during a multi-hour run cannot skew the measured duration.
t1 = time.perf_counter()
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = time.perf_counter()

# Elapsed seconds. This includes everything done inside trainer() (e.g. checkpoint
# saving/upload by save_cb), so time_per_step is an upper bound on pure step time.
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f} min, Time per step/batch: {time_per_step:.4f} s'

cuda
Model params 10.4M
Total steps per epoch: 22526
Start epoch: 0
Batch 100 / 22526 - loss: 6.067553520202637, last 100 batches mean loss: 8.3499
Total processed tokens: 6.7M
Batch 200 / 22526 - loss: 4.299870014190674, last 100 batches mean loss: 4.6892
Total processed tokens: 13.3M
Batch 300 / 22526 - loss: 3.7969117164611816, last 100 batches mean loss: 3.8338
Total processed tokens: 19.9M
Batch 400 / 22526 - loss: 3.9267807006835938, last 100 batches mean loss: 3.5620
Total processed tokens: 26.6M
Batch 500 / 22526 - loss: 3.215667724609375, last 100 batches mean loss: 3.4129
Total processed tokens: 33.2M
Batch 600 / 22526 - loss: 3.071697950363159, last 100 batches mean loss: 3.2585
Total processed tokens: 39.8M
Batch 700 / 22526 - loss: 3.1588120460510254, last 100 batches mean loss: 3.1518
Total processed tokens: 46.4M
Batch 800 / 22526 - loss: 3.047050952911377, last 100 batches mean loss: 2.9960
Total processed tokens: 53.1M
Batch 900 / 22526 - loss: 2.750436544418335, last 

model.safetensors:   0%|          | 0.00/41.6M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.8306
Finished training! All losses:
[np.float64(1.8305851418681918)]
Total training tokens: 1.49B
Final model saved to ./att_mini/xsqa/final_model


model.safetensors:   0%|          | 0.00/41.6M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/xSQAT-m


'Total time: 234.69, Time per step/batch: 0.6251'

### Symmetric SQA Variant

In [None]:
# Symmetric SQA variant: 8 query groups and 8 att_groups; everything else from base_config.
ssqa_config = dict(base_config, att_type='sqa', att_num_query_groups=8, att_groups=8)

ssqa_decoder = ExperimentalAttentionTransformer(**ssqa_config)
ssqa_decoder.params_count()

'Model params 10.9M'

In [None]:
# selected model: sSQA
import time

decoder = ssqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by the previous run before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate is scaled linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = len(wiki_train_dataset)

steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Optimizer steps over the whole run; LR warmup spans the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att_mini/tensorboard_logs/ssqa'
# exist_ok replaces the separate exists() check (and its check-then-create race).
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8B looks like a token cap set far above this run's ~1.5B tokens - confirm TokenCounterCallback semantics.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

save_cb = ModelSaveCallback('./att_mini/ssqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/sSQAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='sSQA Mini Transformer v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# use_moe_aux_loss=False, so moe_aux_loss_scale presumably has no effect for this run.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# perf_counter is monotonic, unlike wall-clock datetime.now(), so system clock
# adjustments during a multi-hour run cannot skew the measured duration.
t1 = time.perf_counter()
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = time.perf_counter()

# Elapsed seconds. This includes everything done inside trainer() (e.g. checkpoint
# saving/upload by save_cb), so time_per_step is an upper bound on pure step time.
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f} min, Time per step/batch: {time_per_step:.4f} s'

cuda
Model params 10.9M
Total steps per epoch: 22526
Start epoch: 0
Batch 100 / 22526 - loss: 5.44687032699585, last 100 batches mean loss: 8.3688
Total processed tokens: 6.71M
Batch 200 / 22526 - loss: 4.104604721069336, last 100 batches mean loss: 4.6433
Total processed tokens: 13.3M
Batch 300 / 22526 - loss: 3.725767135620117, last 100 batches mean loss: 3.8045
Total processed tokens: 20M
Batch 400 / 22526 - loss: 3.5195744037628174, last 100 batches mean loss: 3.5156
Total processed tokens: 26.5M
Batch 500 / 22526 - loss: 2.967064142227173, last 100 batches mean loss: 3.4054
Total processed tokens: 33.1M
Batch 600 / 22526 - loss: 3.0514063835144043, last 100 batches mean loss: 3.2788
Total processed tokens: 39.7M
Batch 700 / 22526 - loss: 3.2906408309936523, last 100 batches mean loss: 3.1211
Total processed tokens: 46.4M
Batch 800 / 22526 - loss: 2.9307174682617188, last 100 batches mean loss: 3.0128
Total processed tokens: 53M
Batch 900 / 22526 - loss: 2.9929447174072266, last 10

model.safetensors:   0%|          | 0.00/43.7M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.8077
Finished training! All losses:
[np.float64(1.807710154081844)]
Total training tokens: 1.49B
Final model saved to ./att_mini/ssqa/final_model


model.safetensors:   0%|          | 0.00/43.7M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/sSQAT-m


'Total time: 242.82, Time per step/batch: 0.6468'

### Extreme Sparse Multi Query Attention

In [None]:
# Extreme sparse multi-query variant: 4 query groups and a single att_group; everything else from base_config.
xsmqa_config = dict(base_config, att_type='sqa', att_num_query_groups=4, att_groups=1)

xsmqa_decoder = ExperimentalAttentionTransformer(**xsmqa_config)
xsmqa_decoder.params_count()

'Model params 10.2M'

In [None]:
# selected model: xSMQA
import time

decoder = xsmqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by the previous run before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 128
epochs = 1
gradient_acc_steps = 1

# Peak learning rate is scaled linearly with the number of accumulated batches.
peak_lr = 5e-4 * gradient_acc_steps

subset_len = len(wiki_train_dataset)

steps_per_epoch = int(subset_len / batch_size - 1)

print(decoder.params_count())
print(f'Total steps per epoch: {steps_per_epoch}')

# Optimizer steps over the whole run; LR warmup spans the first 25% of an epoch.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = int(0.25 * steps_per_epoch)

logs_dir = './att_mini/tensorboard_logs/xsmqa'
# exist_ok replaces the separate exists() check (and its check-then-create race).
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8B looks like a token cap set far above this run's ~1.5B tokens - confirm TokenCounterCallback semantics.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

save_cb = ModelSaveCallback('./att_mini/xsmqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/xSMQAT-m', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='xSMQA Mini Transformer v1')

train_dataset = AutoregressiveLMDataset(wiki_train_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(wiki_valid_dataset, tokenizer, max_seq_len=seq_len)
# use_moe_aux_loss=False, so moe_aux_loss_scale presumably has no effect for this run.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=False, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# perf_counter is monotonic, unlike wall-clock datetime.now(), so system clock
# adjustments during a multi-hour run cannot skew the measured duration.
t1 = time.perf_counter()
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = time.perf_counter()

# Elapsed seconds. This includes everything done inside trainer() (e.g. checkpoint
# saving/upload by save_cb), so time_per_step is an upper bound on pure step time.
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f} min, Time per step/batch: {time_per_step:.4f} s'

cuda
Model params 10.2M
Total steps per epoch: 22526
Start epoch: 0
Batch 100 / 22526 - loss: 5.494544506072998, last 100 batches mean loss: 8.0749
Total processed tokens: 6.65M
Batch 200 / 22526 - loss: 4.144289493560791, last 100 batches mean loss: 4.7138
Total processed tokens: 13.3M
Batch 300 / 22526 - loss: 3.744781732559204, last 100 batches mean loss: 3.8063
Total processed tokens: 19.9M
Batch 400 / 22526 - loss: 3.6983964443206787, last 100 batches mean loss: 3.5376
Total processed tokens: 26.5M
Batch 500 / 22526 - loss: 3.6177046298980713, last 100 batches mean loss: 3.4335
Total processed tokens: 33.2M
Batch 600 / 22526 - loss: 3.2346198558807373, last 100 batches mean loss: 3.3028
Total processed tokens: 39.8M
Batch 700 / 22526 - loss: 2.8668017387390137, last 100 batches mean loss: 3.1315
Total processed tokens: 46.4M
Batch 800 / 22526 - loss: 3.173269510269165, last 100 batches mean loss: 3.0353
Total processed tokens: 53M
Batch 900 / 22526 - loss: 2.874671697616577, last 

model.safetensors:   0%|          | 0.00/40.8M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.8856
Finished training! All losses:
[np.float64(1.885637110710652)]
Total training tokens: 1.49B
Final model saved to ./att_mini/xsmqa/final_model


model.safetensors:   0%|          | 0.00/40.8M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/xSMQAT-m


'Total time: 234.63, Time per step/batch: 0.6250'

## Mini models computational performance comparison


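- Nvidia L40S 48GB GPU
- 200 steps
- **Note:** the `MQA` rows below were produced by a version of the benchmark cell that timed the sSQA model under the `MQA` key (hence the near-identical sSQA/MQA numbers). They do not measure MQA and need to be re-run.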
- 32 batch size
```
Model:   (time per 200 steps, time per single batch)
'GQA':   (3.097075939178467, 0.015485379695892334)
'GQAs':  (3.107764959335327, 0.015538824796676636)
'SQA':   (2.827148914337158, 0.014135744571685791)
'xSQA':  (2.6865649223327637, 0.013432824611663818)
'xSMQA': (2.6817448139190674, 0.013408724069595337)
'sSQA':  (2.9016659259796143, 0.014508329629898072)
'MQA':   (2.898437023162842, 0.014492185115814208)
```
- 128 batch size
```
 Model:  (time per 200 steps, time per single batch)
'GQA':   (19.43833899497986, 0.0971916949748993)
'GQAs':  (19.273487091064453, 0.09636743545532227)
'SQA':   (16.972065210342407, 0.08486032605171204)
'xSQA':  (16.232131004333496, 0.08116065502166749)
'xSMQA': (15.909909963607788, 0.07954954981803894)
'sSQA':  (17.343858003616333, 0.08671929001808167)
'MQA':   (17.349792957305908, 0.08674896478652955)
```

In [None]:
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
  torch.cuda.empty_cache()
print(device)

# Dummy input of 128 sequences x 1024 token ids; only throughput is measured, so the content is irrelevant.
inp = torch.ones(128, 1024, dtype=torch.int32).to(device)

# Models to compare, in the same order as the results table above.
# (Bug fix: the previous version timed ssqa_decoder under the 'MQA' key.)
benchmark_models = {
    'GQA': gqa_decoder,
    'GQAs': gqas_decoder,
    'SQA': sqa_decoder,
    'xSQA': xsqa_decoder,
    'xSMQA': xsmqa_decoder,
    'sSQA': ssqa_decoder,
    'MQA': mqa_decoder,
}
for model in benchmark_models.values():
  model.to(device)  # nn.Module.to moves parameters in place

steps = 200   # timed forward passes per model (the table above reports "time per 200 steps")
warmup = 100  # untimed forward passes run before timing EACH model, so no model pays warmup cost


def _sync():
  """Wait for queued CUDA kernels so wall-clock timings cover the work actually launched."""
  if device.type == 'cuda':
    torch.cuda.synchronize()


# name -> (total seconds for `steps` forward passes, seconds per forward pass)
results = {}

# Forward-only benchmark: no_grad avoids building autograd graphs that are never used.
with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
  print(inp.dtype)
  for name, model in benchmark_models.items():
    for _ in range(warmup):
      _ = model(inp)
    _sync()

    t1 = time.perf_counter()
    for _ in range(steps):
      _ = model(inp)
    _sync()  # CUDA launches are async; without this, t2 would be taken before the kernels finish
    t2 = time.perf_counter()

    results[name] = (t2 - t1, (t2 - t1) / steps)

results

cuda


NameError: name 'sqa_decoder' is not defined

# SQA - Micro MoE models full training test

> Tests of **SQA** variants on micro Mixture-of-Experts architectures, compared against **GQA** and **MQA**.

## Init models

In [None]:
embed_dim = 128
vocab_size = 5_000
seq_len = 256

# Seed every RNG before the models are built so their initial weights are reproducible.
torch.random.manual_seed(2137)
np.random.seed(2137)
random.seed(2137)

# Shared micro MoE architecture; each variant below only overrides its attention settings.
base_config = dict(
    num_layers=6,
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    att_heads=8,
    att_groups=2,
    seq_len=seq_len,
    use_flash_attention=False,
    use_gated=True,
    ff_dropout=0.1,
    ff_activation='silu',
    ff_dim=256,
    use_moe_ff=True,
    ff_num_experts=12,
    ff_moe_top_k=2,
)

gqa_config = dict(base_config, att_type='gqa')
mqa_config = dict(base_config, att_type='mqa')
sqa_config = dict(base_config, att_type='sqa', att_num_query_groups=4)
ssqa_config = dict(base_config, att_type='sqa', att_num_query_groups=4, att_groups=4)
xsqa_config = dict(base_config, att_type='sqa', att_num_query_groups=2)

# Construction order matters: each model consumes the seeded RNG stream in turn.
gqa_decoder = ExperimentalAttentionTransformer(**gqa_config)
mqa_decoder = ExperimentalAttentionTransformer(**mqa_config)
sqa_decoder = ExperimentalAttentionTransformer(**sqa_config)
ssqa_decoder = ExperimentalAttentionTransformer(**ssqa_config)
xsqa_decoder = ExperimentalAttentionTransformer(**xsqa_config)

tuple(
    (label, model.params_count())
    for label, model in (
        ('GQA', gqa_decoder),
        ('MQA', mqa_decoder),
        ('SQA', sqa_decoder),
        ('sSQA', ssqa_decoder),
        ('xSQA', xsqa_decoder),
    )
)

(('GQA', 'Model params 8.67M'),
 ('MQA', 'Model params 8.64M'),
 ('SQA', 'Model params 8.57M'),
 ('sSQA', 'Model params 8.62M'),
 ('xSQA', 'Model params 8.52M'))

In [None]:
# Load the tokenizer published with ReactiveAI/RxT-Alpha-Micro-Decoder and expose it as a
# Hugging Face tokenizer; its vocab_size (5000) matches the vocab_size set for the micro models.
tr = TokenizerTrainer.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Decoder')
tokenizer = tr.get_hf_tokenizer()
tokenizer

tokenizer.json:   0%|          | 0.00/324k [00:00<?, ?B/s]

PreTrainedTokenizerFast(name_or_path='', vocab_size=5000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("[EOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	4: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	5000: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	5001: AddedTok

In [None]:
from datasets import load_dataset

# TinyStories from the Hugging Face Hub: 'train' feeds the trainer, 'validation' is used for evaluation.
# NOTE(review): trust_remote_code=True lets the dataset repo run code; it is likely unnecessary for this parquet dataset.
stories_dataset, hf_valid_dataset = (
    load_dataset('roneneldan/TinyStories', split=split_name, trust_remote_code=True)
    for split_name in ('train', 'validation')
)
len(stories_dataset), len(hf_valid_dataset)

README.md:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

(…)-00000-of-00004-2d5a1467fff1081b.parquet:   0%|          | 0.00/249M [00:00<?, ?B/s]

(…)-00001-of-00004-5852b56a2bd28fd9.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00002-of-00004-a26307300439e943.parquet:   0%|          | 0.00/246M [00:00<?, ?B/s]

(…)-00003-of-00004-d243063613e5a057.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00000-of-00001-869c898b519ad725.parquet:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

(2119719, 21990)

## GQA

In [None]:
# selected model: GQA
import time

decoder = gqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by the previous run before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 5
gradient_acc_steps = 1

# Peak learning rate is scaled linearly with the number of accumulated batches.
peak_lr = 2e-3 * gradient_acc_steps

subset_len = len(stories_dataset)

steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

# Optimizer steps over the whole run; the micro MoE runs use no LR warmup.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = 0

logs_dir = './micro_att/tensorboard_logs/gqa'
# exist_ok replaces the separate exists() check (and its check-then-create race).
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8B looks like a token cap set above this run's ~2.3B tokens - confirm TokenCounterCallback semantics.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

save_cb = ModelSaveCallback('./micro_att/gqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/GQA-Ref-Micro', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='GQA Ref Micro Transformer v2')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# gradient_acc_steps already scales peak_lr and total_steps, so forward it to the trainer
# as well (as the Mini-model cells do) to keep all three consistent if it is changed.
# The MoE feed-forward layers get an auxiliary loss weighted by moe_aux_loss_scale.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# perf_counter is monotonic, unlike wall-clock datetime.now(), so system clock
# adjustments during a multi-hour run cannot skew the measured duration.
t1 = time.perf_counter()
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = time.perf_counter()

# Elapsed seconds. This includes everything done inside trainer() (e.g. checkpoint
# saving/upload by save_cb), so time_per_step is an upper bound on pure step time.
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f} min, Time per step/batch: {time_per_step:.4f} s'

cuda
Total steps per epoch: 8279
Start epoch: 0
Batch 100 / 8279 - loss: 2.629328727722168, last 100 batches mean loss: 3.3715
Total processed tokens: 5.56M
Batch 200 / 8279 - loss: 2.3235981464385986, last 100 batches mean loss: 2.4578
Total processed tokens: 11.1M
Batch 300 / 8279 - loss: 2.015925407409668, last 100 batches mean loss: 2.1836
Total processed tokens: 16.6M
Batch 400 / 8279 - loss: 2.0380570888519287, last 100 batches mean loss: 2.0222
Total processed tokens: 22.1M
Batch 500 / 8279 - loss: 1.8263487815856934, last 100 batches mean loss: 1.9112
Total processed tokens: 27.6M
Batch 600 / 8279 - loss: 1.7749731540679932, last 100 batches mean loss: 1.8324
Total processed tokens: 33.1M
Batch 700 / 8279 - loss: 1.7805384397506714, last 100 batches mean loss: 1.7679
Total processed tokens: 38.6M
Batch 800 / 8279 - loss: 1.6638338565826416, last 100 batches mean loss: 1.7176
Total processed tokens: 44.1M
Batch 900 / 8279 - loss: 1.6863948106765747, last 100 batches mean loss: 1

model.safetensors:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.4903
Start epoch: 1
Batch 100 / 8279 - loss: 1.3069393634796143, last 100 batches mean loss: 1.3276
Total processed tokens: 461M
Batch 200 / 8279 - loss: 1.342293620109558, last 100 batches mean loss: 1.3204
Total processed tokens: 467M
Batch 300 / 8279 - loss: 1.3354506492614746, last 100 batches mean loss: 1.3275
Total processed tokens: 472M
Batch 400 / 8279 - loss: 1.300283432006836, last 100 batches mean loss: 1.3190
Total processed tokens: 478M
Batch 500 / 8279 - loss: 1.3299413919448853, last 100 batches mean loss: 1.3194
Total processed tokens: 483M
Batch 600 / 8279 - loss: 1.3457027673721313, last 100 batches mean loss: 1.3163
Total processed tokens: 489M
Batch 700 / 8279 - loss: 1.3116440773010254, last 100 batches mean loss: 1.3139
Total processed tokens: 494M
Batch 800 / 8279 - loss: 1.292658805847168, last 100 batches mean loss: 1.3231
Total processed tokens: 500M
Batch 900 / 8279 - loss: 1.3298698663711548, last 100 batches mean loss: 1.3225
Total pr

model.safetensors:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

Epoch 1 - mean loss: 1.2920
Start epoch: 2
Batch 100 / 8279 - loss: 1.2388347387313843, last 100 batches mean loss: 1.2612
Total processed tokens: 917M
Batch 200 / 8279 - loss: 1.2051705121994019, last 100 batches mean loss: 1.2600
Total processed tokens: 922M
Batch 300 / 8279 - loss: 1.2736046314239502, last 100 batches mean loss: 1.2582
Total processed tokens: 928M
Batch 400 / 8279 - loss: 1.2768995761871338, last 100 batches mean loss: 1.2584
Total processed tokens: 934M
Batch 500 / 8279 - loss: 1.2924998998641968, last 100 batches mean loss: 1.2569
Total processed tokens: 939M
Batch 600 / 8279 - loss: 1.2358978986740112, last 100 batches mean loss: 1.2564
Total processed tokens: 945M
Batch 700 / 8279 - loss: 1.277685284614563, last 100 batches mean loss: 1.2606
Total processed tokens: 950M
Batch 800 / 8279 - loss: 1.2278023958206177, last 100 batches mean loss: 1.2610
Total processed tokens: 956M
Batch 900 / 8279 - loss: 1.246586799621582, last 100 batches mean loss: 1.2525
Total p

model.safetensors:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

Epoch 2 - mean loss: 1.2436
Start epoch: 3
Batch 100 / 8279 - loss: 1.2368011474609375, last 100 batches mean loss: 1.2265
Total processed tokens: 1.37B
Batch 200 / 8279 - loss: 1.228379487991333, last 100 batches mean loss: 1.2256
Total processed tokens: 1.38B
Batch 300 / 8279 - loss: 1.2304444313049316, last 100 batches mean loss: 1.2197
Total processed tokens: 1.38B
Batch 400 / 8279 - loss: 1.2461113929748535, last 100 batches mean loss: 1.2249
Total processed tokens: 1.39B
Batch 500 / 8279 - loss: 1.2029244899749756, last 100 batches mean loss: 1.2218
Total processed tokens: 1.39B
Batch 600 / 8279 - loss: 1.2347424030303955, last 100 batches mean loss: 1.2213
Total processed tokens: 1.4B
Batch 700 / 8279 - loss: 1.2236381769180298, last 100 batches mean loss: 1.2258
Total processed tokens: 1.41B
Batch 800 / 8279 - loss: 1.2222919464111328, last 100 batches mean loss: 1.2268
Total processed tokens: 1.41B
Batch 900 / 8279 - loss: 1.2304718494415283, last 100 batches mean loss: 1.2215

model.safetensors:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

Epoch 3 - mean loss: 1.2124
Start epoch: 4
Batch 100 / 8279 - loss: 1.2109521627426147, last 100 batches mean loss: 1.2002
Total processed tokens: 1.83B
Batch 200 / 8279 - loss: 1.2083141803741455, last 100 batches mean loss: 1.1961
Total processed tokens: 1.83B
Batch 300 / 8279 - loss: 1.1494450569152832, last 100 batches mean loss: 1.1985
Total processed tokens: 1.84B
Batch 400 / 8279 - loss: 1.2192821502685547, last 100 batches mean loss: 1.2038
Total processed tokens: 1.84B
Batch 500 / 8279 - loss: 1.2151356935501099, last 100 batches mean loss: 1.1996
Total processed tokens: 1.85B
Batch 600 / 8279 - loss: 1.217910647392273, last 100 batches mean loss: 1.1970
Total processed tokens: 1.86B
Batch 700 / 8279 - loss: 1.1916956901550293, last 100 batches mean loss: 1.1982
Total processed tokens: 1.86B
Batch 800 / 8279 - loss: 1.1587404012680054, last 100 batches mean loss: 1.1994
Total processed tokens: 1.87B
Batch 900 / 8279 - loss: 1.2022173404693604, last 100 batches mean loss: 1.201

model.safetensors:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

Epoch 4 - mean loss: 1.1957
Finished training! All losses:
[np.float64(1.4902646732359117), np.float64(1.2920006471987509), np.float64(1.2435973253659005), np.float64(1.2124340490874461), np.float64(1.1956647237141926)]
Total training tokens: 2.28B
Final model saved to ./micro_att/gqa/final_model


model.safetensors:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/GQA-Ref-Micro


'Total time: 397.71, Time per step/batch: 0.5765'

## MQA

In [None]:
# selected model: MQA
import time

decoder = mqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Release memory held by the previous run before building a new trainer.
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 5
gradient_acc_steps = 1

# Peak learning rate is scaled linearly with the number of accumulated batches.
peak_lr = 2e-3 * gradient_acc_steps

subset_len = len(stories_dataset)

steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

# Optimizer steps over the whole run; the micro MoE runs use no LR warmup.
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
warmup_steps = 0

logs_dir = './micro_att/tensorboard_logs/mqa'
# exist_ok replaces the separate exists() check (and its check-then-create race).
os.makedirs(logs_dir, exist_ok=True)

print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
# NOTE(review): 8B looks like a token cap set above this run's token count - confirm TokenCounterCallback semantics.
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

save_cb = ModelSaveCallback('./micro_att/mqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/MQA-Ref-Micro', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='MQA Ref Micro Transformer v2')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
# gradient_acc_steps already scales peak_lr and total_steps, so forward it to the trainer
# as well (as the Mini-model cells do) to keep all three consistent if it is changed.
# The MoE feed-forward layers get an auxiliary loss weighted by moe_aux_loss_scale.
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# perf_counter is monotonic, unlike wall-clock datetime.now(), so system clock
# adjustments during a multi-hour run cannot skew the measured duration.
t1 = time.perf_counter()
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = time.perf_counter()

# Elapsed seconds. This includes everything done inside trainer() (e.g. checkpoint
# saving/upload by save_cb), so time_per_step is an upper bound on pure step time.
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f} min, Time per step/batch: {time_per_step:.4f} s'

cuda
Total steps per epoch: 8279
Start epoch: 0
Batch 100 / 8279 - loss: 2.672910213470459, last 100 batches mean loss: 3.3908
Total processed tokens: 5.56M
Batch 200 / 8279 - loss: 2.274366855621338, last 100 batches mean loss: 2.4510
Total processed tokens: 11.1M
Batch 300 / 8279 - loss: 2.118898391723633, last 100 batches mean loss: 2.1822
Total processed tokens: 16.6M
Batch 400 / 8279 - loss: 1.9398634433746338, last 100 batches mean loss: 2.0306
Total processed tokens: 22.1M
Batch 500 / 8279 - loss: 1.8738023042678833, last 100 batches mean loss: 1.9220
Total processed tokens: 27.6M
Batch 600 / 8279 - loss: 1.8154195547103882, last 100 batches mean loss: 1.8486
Total processed tokens: 33.1M
Batch 700 / 8279 - loss: 1.7756985425949097, last 100 batches mean loss: 1.7952
Total processed tokens: 38.6M
Batch 800 / 8279 - loss: 1.7227400541305542, last 100 batches mean loss: 1.7465
Total processed tokens: 44.1M
Batch 900 / 8279 - loss: 1.7109880447387695, last 100 batches mean loss: 1.

model.safetensors:   0%|          | 0.00/34.6M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.5166
Start epoch: 1
Batch 100 / 8279 - loss: 1.3445544242858887, last 100 batches mean loss: 1.3498
Total processed tokens: 461M
Batch 200 / 8279 - loss: 1.4054932594299316, last 100 batches mean loss: 1.3464
Total processed tokens: 467M
Batch 300 / 8279 - loss: 1.332061529159546, last 100 batches mean loss: 1.3438
Total processed tokens: 472M
Batch 400 / 8279 - loss: 1.3296546936035156, last 100 batches mean loss: 1.3438
Total processed tokens: 478M
Batch 500 / 8279 - loss: 1.3651366233825684, last 100 batches mean loss: 1.3413
Total processed tokens: 483M
Batch 600 / 8279 - loss: 1.3486720323562622, last 100 batches mean loss: 1.3420
Total processed tokens: 489M
Batch 700 / 8279 - loss: 1.30489182472229, last 100 batches mean loss: 1.3454
Total processed tokens: 494M
Batch 800 / 8279 - loss: 1.306792974472046, last 100 batches mean loss: 1.3411
Total processed tokens: 500M
Batch 900 / 8279 - loss: 1.3353289365768433, last 100 batches mean loss: 1.3420
Total pro

model.safetensors:   0%|          | 0.00/34.6M [00:00<?, ?B/s]

Epoch 1 - mean loss: 1.3139
Start epoch: 2
Batch 100 / 8279 - loss: 1.2703547477722168, last 100 batches mean loss: 1.2811
Total processed tokens: 917M
Batch 200 / 8279 - loss: 1.2717329263687134, last 100 batches mean loss: 1.2786
Total processed tokens: 922M
Batch 300 / 8279 - loss: 1.3366533517837524, last 100 batches mean loss: 1.2812
Total processed tokens: 928M
Batch 400 / 8279 - loss: 1.3056503534317017, last 100 batches mean loss: 1.2836
Total processed tokens: 933M
Batch 500 / 8279 - loss: 1.2707017660140991, last 100 batches mean loss: 1.2776
Total processed tokens: 939M
Batch 600 / 8279 - loss: 1.265843391418457, last 100 batches mean loss: 1.2847
Total processed tokens: 945M
Batch 700 / 8279 - loss: 1.3024604320526123, last 100 batches mean loss: 1.2777
Total processed tokens: 950M
Batch 800 / 8279 - loss: 1.307152271270752, last 100 batches mean loss: 1.2811
Total processed tokens: 956M
Batch 900 / 8279 - loss: 1.275983214378357, last 100 batches mean loss: 1.2780
Total pr

model.safetensors:   0%|          | 0.00/34.6M [00:00<?, ?B/s]

Epoch 2 - mean loss: 1.2642
Start epoch: 3
Batch 100 / 8279 - loss: 1.2537927627563477, last 100 batches mean loss: 1.2442
Total processed tokens: 1.37B
Batch 200 / 8279 - loss: 1.227031946182251, last 100 batches mean loss: 1.2445
Total processed tokens: 1.38B
Batch 300 / 8279 - loss: 1.2426995038986206, last 100 batches mean loss: 1.2410
Total processed tokens: 1.38B
Batch 400 / 8279 - loss: 1.2456775903701782, last 100 batches mean loss: 1.2442
Total processed tokens: 1.39B
Batch 500 / 8279 - loss: 1.261172890663147, last 100 batches mean loss: 1.2431
Total processed tokens: 1.39B
Batch 600 / 8279 - loss: 1.194532871246338, last 100 batches mean loss: 1.2389
Total processed tokens: 1.4B
Batch 700 / 8279 - loss: 1.2443511486053467, last 100 batches mean loss: 1.2384
Total processed tokens: 1.41B
Batch 800 / 8279 - loss: 1.2835891246795654, last 100 batches mean loss: 1.2414
Total processed tokens: 1.41B
Batch 900 / 8279 - loss: 1.2461153268814087, last 100 batches mean loss: 1.2383
T

model.safetensors:   0%|          | 0.00/34.6M [00:00<?, ?B/s]

Epoch 3 - mean loss: 1.2327
Start epoch: 4
Batch 100 / 8279 - loss: 1.1744719743728638, last 100 batches mean loss: 1.2216
Total processed tokens: 1.83B
Batch 200 / 8279 - loss: 1.269881248474121, last 100 batches mean loss: 1.2177
Total processed tokens: 1.83B
Batch 300 / 8279 - loss: 1.2106077671051025, last 100 batches mean loss: 1.2181
Total processed tokens: 1.84B
Batch 400 / 8279 - loss: 1.2338532209396362, last 100 batches mean loss: 1.2213
Total processed tokens: 1.84B
Batch 500 / 8279 - loss: 1.221483588218689, last 100 batches mean loss: 1.2185
Total processed tokens: 1.85B
Batch 600 / 8279 - loss: 1.2366887331008911, last 100 batches mean loss: 1.2161
Total processed tokens: 1.86B
Batch 700 / 8279 - loss: 1.2793567180633545, last 100 batches mean loss: 1.2157
Total processed tokens: 1.86B
Batch 800 / 8279 - loss: 1.2040153741836548, last 100 batches mean loss: 1.2180
Total processed tokens: 1.87B
Batch 900 / 8279 - loss: 1.2570245265960693, last 100 batches mean loss: 1.2198

model.safetensors:   0%|          | 0.00/34.6M [00:00<?, ?B/s]

Epoch 4 - mean loss: 1.2152
Finished training! All losses:
[np.float64(1.5166365797105044), np.float64(1.3138795965679602), np.float64(1.2642069910459472), np.float64(1.2327262416409985), np.float64(1.2152393094295464)]
Total training tokens: 2.28B
Final model saved to ./micro_att/mqa/final_model


model.safetensors:   0%|          | 0.00/34.6M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/MQA-Ref-Micro


'Total time: 398.65, Time per step/batch: 0.5778'

## SQA

In [None]:
# selected model: SQA
# Trains the SQA decoder with the same recipe as the other attention variants,
# so the loss curves and timings are directly comparable.
import time

decoder = sqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# The previous run's trainer/optimizer/scheduler are still bound to these global
# names and keep the old optimizer state (and whatever the trainer holds) alive,
# presumably on the GPU. Drop those references first; otherwise
# gc.collect()/empty_cache() below cannot release that memory.
trainer = optimizer = scheduler = None
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 5
gradient_acc_steps = 1

# Peak LR scaled linearly with gradient accumulation (no-op while gradient_acc_steps == 1).
peak_lr = 2e-3 * gradient_acc_steps

subset_len = len(stories_dataset)

# Batches per epoch, used for the loss printer and the LR schedule length.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# Warmup disabled; the 10% alternative was: warmup_steps = int(0.1 * total_steps)
warmup_steps = 0


logs_dir = './micro_att/tensorboard_logs/sqa'
os.makedirs(logs_dir, exist_ok=True)


print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves checkpoints locally and pushes them (and the final model) to a private HF Hub repo.
save_cb = ModelSaveCallback('./micro_att/sqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/SQAT-mm', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='SQA-mm Transformer v1.0.0')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Monotonic clock: elapsed time is immune to system-clock adjustments during a long run.
t1 = time.perf_counter()
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = time.perf_counter()

# NOTE: this covers everything the trainer and callbacks do (e.g. per-epoch
# checkpoint uploads), not only forward/backward steps.
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f} min, Time per step/batch: {time_per_step:.4f} s'

cuda
Total steps per epoch: 8279
Start epoch: 0
Batch 100 / 8279 - loss: 2.6336724758148193, last 100 batches mean loss: 3.3759
Total processed tokens: 5.55M
Batch 200 / 8279 - loss: 2.282727003097534, last 100 batches mean loss: 2.4487
Total processed tokens: 11M
Batch 300 / 8279 - loss: 2.0732598304748535, last 100 batches mean loss: 2.1834
Total processed tokens: 16.5M
Batch 400 / 8279 - loss: 1.9777555465698242, last 100 batches mean loss: 2.0304
Total processed tokens: 22.1M
Batch 500 / 8279 - loss: 1.8751907348632812, last 100 batches mean loss: 1.9223
Total processed tokens: 27.6M
Batch 600 / 8279 - loss: 1.8109139204025269, last 100 batches mean loss: 1.8445
Total processed tokens: 33.1M
Batch 700 / 8279 - loss: 1.7999869585037231, last 100 batches mean loss: 1.7883
Total processed tokens: 38.6M
Batch 800 / 8279 - loss: 1.7318172454833984, last 100 batches mean loss: 1.7492
Total processed tokens: 44.1M
Batch 900 / 8279 - loss: 1.701560139656067, last 100 batches mean loss: 1.7

model.safetensors:   0%|          | 0.00/34.3M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.5148
Start epoch: 1
Batch 100 / 8279 - loss: 1.335120677947998, last 100 batches mean loss: 1.3496
Total processed tokens: 461M
Batch 200 / 8279 - loss: 1.3821160793304443, last 100 batches mean loss: 1.3449
Total processed tokens: 467M
Batch 300 / 8279 - loss: 1.3524575233459473, last 100 batches mean loss: 1.3477
Total processed tokens: 472M
Batch 400 / 8279 - loss: 1.3709241151809692, last 100 batches mean loss: 1.3436
Total processed tokens: 478M
Batch 500 / 8279 - loss: 1.349852204322815, last 100 batches mean loss: 1.3433
Total processed tokens: 483M
Batch 600 / 8279 - loss: 1.3550982475280762, last 100 batches mean loss: 1.3442
Total processed tokens: 489M
Batch 700 / 8279 - loss: 1.3836008310317993, last 100 batches mean loss: 1.3444
Total processed tokens: 494M
Batch 800 / 8279 - loss: 1.342641830444336, last 100 batches mean loss: 1.3368
Total processed tokens: 500M
Batch 900 / 8279 - loss: 1.3414976596832275, last 100 batches mean loss: 1.3401
Total pr

model.safetensors:   0%|          | 0.00/34.3M [00:00<?, ?B/s]

Epoch 1 - mean loss: 1.3140
Start epoch: 2
Batch 100 / 8279 - loss: 1.3021323680877686, last 100 batches mean loss: 1.2814
Total processed tokens: 917M
Batch 200 / 8279 - loss: 1.2639914751052856, last 100 batches mean loss: 1.2799
Total processed tokens: 922M
Batch 300 / 8279 - loss: 1.2740983963012695, last 100 batches mean loss: 1.2796
Total processed tokens: 928M
Batch 400 / 8279 - loss: 1.273608922958374, last 100 batches mean loss: 1.2807
Total processed tokens: 933M
Batch 500 / 8279 - loss: 1.2711708545684814, last 100 batches mean loss: 1.2808
Total processed tokens: 939M
Batch 600 / 8279 - loss: 1.2767144441604614, last 100 batches mean loss: 1.2783
Total processed tokens: 944M
Batch 700 / 8279 - loss: 1.2677087783813477, last 100 batches mean loss: 1.2783
Total processed tokens: 950M
Batch 800 / 8279 - loss: 1.269129753112793, last 100 batches mean loss: 1.2800
Total processed tokens: 955M
Batch 900 / 8279 - loss: 1.3269160985946655, last 100 batches mean loss: 1.2777
Total p

model.safetensors:   0%|          | 0.00/34.3M [00:00<?, ?B/s]

Epoch 2 - mean loss: 1.2648
Start epoch: 3
Batch 100 / 8279 - loss: 1.224625587463379, last 100 batches mean loss: 1.2449
Total processed tokens: 1.37B
Batch 200 / 8279 - loss: 1.2364628314971924, last 100 batches mean loss: 1.2392
Total processed tokens: 1.38B
Batch 300 / 8279 - loss: 1.2498807907104492, last 100 batches mean loss: 1.2449
Total processed tokens: 1.38B
Batch 400 / 8279 - loss: 1.2510905265808105, last 100 batches mean loss: 1.2383
Total processed tokens: 1.39B
Batch 500 / 8279 - loss: 1.2183938026428223, last 100 batches mean loss: 1.2440
Total processed tokens: 1.39B
Batch 600 / 8279 - loss: 1.217344880104065, last 100 batches mean loss: 1.2455
Total processed tokens: 1.4B
Batch 700 / 8279 - loss: 1.2690200805664062, last 100 batches mean loss: 1.2426
Total processed tokens: 1.41B
Batch 800 / 8279 - loss: 1.2576708793640137, last 100 batches mean loss: 1.2421
Total processed tokens: 1.41B
Batch 900 / 8279 - loss: 1.2017902135849, last 100 batches mean loss: 1.2439
Tot

model.safetensors:   0%|          | 0.00/34.3M [00:00<?, ?B/s]

Epoch 3 - mean loss: 1.2334
Start epoch: 4
Batch 100 / 8279 - loss: 1.2028932571411133, last 100 batches mean loss: 1.2235
Total processed tokens: 1.83B
Batch 200 / 8279 - loss: 1.230832576751709, last 100 batches mean loss: 1.2208
Total processed tokens: 1.83B
Batch 300 / 8279 - loss: 1.2050509452819824, last 100 batches mean loss: 1.2174
Total processed tokens: 1.84B
Batch 400 / 8279 - loss: 1.2056161165237427, last 100 batches mean loss: 1.2166
Total processed tokens: 1.84B
Batch 500 / 8279 - loss: 1.2073464393615723, last 100 batches mean loss: 1.2153
Total processed tokens: 1.85B
Batch 600 / 8279 - loss: 1.219319224357605, last 100 batches mean loss: 1.2202
Total processed tokens: 1.86B
Batch 700 / 8279 - loss: 1.1848393678665161, last 100 batches mean loss: 1.2215
Total processed tokens: 1.86B
Batch 800 / 8279 - loss: 1.2057225704193115, last 100 batches mean loss: 1.2183
Total processed tokens: 1.87B
Batch 900 / 8279 - loss: 1.1695115566253662, last 100 batches mean loss: 1.2185

model.safetensors:   0%|          | 0.00/34.3M [00:00<?, ?B/s]

Epoch 4 - mean loss: 1.2163
Finished training! All losses:
[np.float64(1.5148357996906059), np.float64(1.3139875709290665), np.float64(1.2647621270539104), np.float64(1.2333957281089636), np.float64(1.2163488218173888)]
Total training tokens: 2.28B
Final model saved to ./micro_att/sqa/final_model


model.safetensors:   0%|          | 0.00/34.3M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/SQAT-mm


'Total time: 387.10, Time per step/batch: 0.5611'

## sSQA

In [None]:
# selected model: sSQA
# Trains the sSQA decoder with the same recipe as the other attention variants,
# so the loss curves and timings are directly comparable.
import time

decoder = ssqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# The previous run's trainer/optimizer/scheduler are still bound to these global
# names and keep the old optimizer state (and whatever the trainer holds) alive,
# presumably on the GPU. Drop those references first; otherwise
# gc.collect()/empty_cache() below cannot release that memory.
trainer = optimizer = scheduler = None
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 5
gradient_acc_steps = 1

# Peak LR scaled linearly with gradient accumulation (no-op while gradient_acc_steps == 1).
peak_lr = 2e-3 * gradient_acc_steps

subset_len = len(stories_dataset)

# Batches per epoch, used for the loss printer and the LR schedule length.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# Warmup disabled; the 10% alternative was: warmup_steps = int(0.1 * total_steps)
warmup_steps = 0


logs_dir = './micro_att/tensorboard_logs/ssqa'
os.makedirs(logs_dir, exist_ok=True)


print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves checkpoints locally and pushes them (and the final model) to a private HF Hub repo.
save_cb = ModelSaveCallback('./micro_att/ssqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/sSQAT-mm', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='sSQA-mm Transformer v1.0.0')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Monotonic clock: elapsed time is immune to system-clock adjustments during a long run.
t1 = time.perf_counter()
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = time.perf_counter()

# NOTE: this covers everything the trainer and callbacks do (e.g. per-epoch
# checkpoint uploads), not only forward/backward steps.
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f} min, Time per step/batch: {time_per_step:.4f} s'

cuda
Total steps per epoch: 8279
Start epoch: 0
Batch 100 / 8279 - loss: 2.6886234283447266, last 100 batches mean loss: 3.3739
Total processed tokens: 5.56M
Batch 200 / 8279 - loss: 2.2805237770080566, last 100 batches mean loss: 2.4369
Total processed tokens: 11.1M
Batch 300 / 8279 - loss: 2.148430585861206, last 100 batches mean loss: 2.1593
Total processed tokens: 16.5M
Batch 400 / 8279 - loss: 1.9415762424468994, last 100 batches mean loss: 2.0000
Total processed tokens: 22M
Batch 500 / 8279 - loss: 1.8783180713653564, last 100 batches mean loss: 1.8907
Total processed tokens: 27.5M
Batch 600 / 8279 - loss: 1.8151317834854126, last 100 batches mean loss: 1.8219
Total processed tokens: 33M
Batch 700 / 8279 - loss: 1.730359435081482, last 100 batches mean loss: 1.7635
Total processed tokens: 38.5M
Batch 800 / 8279 - loss: 1.7125948667526245, last 100 batches mean loss: 1.7254
Total processed tokens: 44.1M
Batch 900 / 8279 - loss: 1.695077896118164, last 100 batches mean loss: 1.6907

model.safetensors:   0%|          | 0.00/34.5M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.4951
Start epoch: 1
Batch 100 / 8279 - loss: 1.3322702646255493, last 100 batches mean loss: 1.3290
Total processed tokens: 461M
Batch 200 / 8279 - loss: 1.3057458400726318, last 100 batches mean loss: 1.3248
Total processed tokens: 467M
Batch 300 / 8279 - loss: 1.3442506790161133, last 100 batches mean loss: 1.3320
Total processed tokens: 472M
Batch 400 / 8279 - loss: 1.3296499252319336, last 100 batches mean loss: 1.3295
Total processed tokens: 478M
Batch 500 / 8279 - loss: 1.3286819458007812, last 100 batches mean loss: 1.3265
Total processed tokens: 483M
Batch 600 / 8279 - loss: 1.3110555410385132, last 100 batches mean loss: 1.3214
Total processed tokens: 489M
Batch 700 / 8279 - loss: 1.3176957368850708, last 100 batches mean loss: 1.3237
Total processed tokens: 494M
Batch 800 / 8279 - loss: 1.3010143041610718, last 100 batches mean loss: 1.3260
Total processed tokens: 500M
Batch 900 / 8279 - loss: 1.2966675758361816, last 100 batches mean loss: 1.3195
Total

model.safetensors:   0%|          | 0.00/34.5M [00:00<?, ?B/s]

Epoch 1 - mean loss: 1.2963
Start epoch: 2
Batch 100 / 8279 - loss: 1.2421655654907227, last 100 batches mean loss: 1.2651
Total processed tokens: 917M
Batch 200 / 8279 - loss: 1.2674124240875244, last 100 batches mean loss: 1.2633
Total processed tokens: 923M
Batch 300 / 8279 - loss: 1.2575300931930542, last 100 batches mean loss: 1.2660
Total processed tokens: 928M
Batch 400 / 8279 - loss: 1.3031741380691528, last 100 batches mean loss: 1.2636
Total processed tokens: 934M
Batch 500 / 8279 - loss: 1.3007197380065918, last 100 batches mean loss: 1.2632
Total processed tokens: 939M
Batch 600 / 8279 - loss: 1.3110626935958862, last 100 batches mean loss: 1.2615
Total processed tokens: 945M
Batch 700 / 8279 - loss: 1.2788128852844238, last 100 batches mean loss: 1.2630
Total processed tokens: 950M
Batch 800 / 8279 - loss: 1.2704474925994873, last 100 batches mean loss: 1.2622
Total processed tokens: 956M
Batch 900 / 8279 - loss: 1.2867733240127563, last 100 batches mean loss: 1.2605
Total

model.safetensors:   0%|          | 0.00/34.5M [00:00<?, ?B/s]

Epoch 2 - mean loss: 1.2473
Start epoch: 3
Batch 100 / 8279 - loss: 1.2150473594665527, last 100 batches mean loss: 1.2287
Total processed tokens: 1.37B
Batch 200 / 8279 - loss: 1.195818543434143, last 100 batches mean loss: 1.2260
Total processed tokens: 1.38B
Batch 300 / 8279 - loss: 1.242146372795105, last 100 batches mean loss: 1.2229
Total processed tokens: 1.38B
Batch 400 / 8279 - loss: 1.2316362857818604, last 100 batches mean loss: 1.2245
Total processed tokens: 1.39B
Batch 500 / 8279 - loss: 1.2184087038040161, last 100 batches mean loss: 1.2252
Total processed tokens: 1.39B
Batch 600 / 8279 - loss: 1.1605932712554932, last 100 batches mean loss: 1.2229
Total processed tokens: 1.4B
Batch 700 / 8279 - loss: 1.229365348815918, last 100 batches mean loss: 1.2253
Total processed tokens: 1.41B
Batch 800 / 8279 - loss: 1.2138029336929321, last 100 batches mean loss: 1.2265
Total processed tokens: 1.41B
Batch 900 / 8279 - loss: 1.1838717460632324, last 100 batches mean loss: 1.2211
T

model.safetensors:   0%|          | 0.00/34.5M [00:00<?, ?B/s]

Epoch 3 - mean loss: 1.2160
Start epoch: 4
Batch 100 / 8279 - loss: 1.1983999013900757, last 100 batches mean loss: 1.2050
Total processed tokens: 1.83B
Batch 200 / 8279 - loss: 1.1833724975585938, last 100 batches mean loss: 1.2001
Total processed tokens: 1.83B
Batch 300 / 8279 - loss: 1.1863914728164673, last 100 batches mean loss: 1.2049
Total processed tokens: 1.84B
Batch 400 / 8279 - loss: 1.199991226196289, last 100 batches mean loss: 1.1976
Total processed tokens: 1.84B
Batch 500 / 8279 - loss: 1.2207093238830566, last 100 batches mean loss: 1.2016
Total processed tokens: 1.85B
Batch 600 / 8279 - loss: 1.201084852218628, last 100 batches mean loss: 1.2009
Total processed tokens: 1.86B
Batch 700 / 8279 - loss: 1.1694027185440063, last 100 batches mean loss: 1.2025
Total processed tokens: 1.86B
Batch 800 / 8279 - loss: 1.16877019405365, last 100 batches mean loss: 1.1991
Total processed tokens: 1.87B
Batch 900 / 8279 - loss: 1.1872689723968506, last 100 batches mean loss: 1.1986
T

model.safetensors:   0%|          | 0.00/34.5M [00:00<?, ?B/s]

Epoch 4 - mean loss: 1.1987
Finished training! All losses:
[np.float64(1.495101890846151), np.float64(1.2963164104931597), np.float64(1.247301292779365), np.float64(1.2160289415414782), np.float64(1.198726304213782)]
Total training tokens: 2.28B
Final model saved to ./micro_att/ssqa/final_model


model.safetensors:   0%|          | 0.00/34.5M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/sSQAT-mm


'Total time: 390.10, Time per step/batch: 0.5654'

## xSQA

In [None]:
# selected model: xSQA
# Trains the xSQA decoder with the same recipe as the other attention variants,
# so the loss curves and timings are directly comparable.
import time

decoder = xsqa_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# The previous run's trainer/optimizer/scheduler are still bound to these global
# names and keep the old optimizer state (and whatever the trainer holds) alive,
# presumably on the GPU. Drop those references first; otherwise
# gc.collect()/empty_cache() below cannot release that memory.
trainer = optimizer = scheduler = None
gc.collect()
if torch.cuda.is_available():
  torch.cuda.empty_cache()

batch_size = 256
epochs = 5
gradient_acc_steps = 1

# Peak LR scaled linearly with gradient accumulation (no-op while gradient_acc_steps == 1).
peak_lr = 2e-3 * gradient_acc_steps

subset_len = len(stories_dataset)

# Batches per epoch, used for the loss printer and the LR schedule length.
steps_per_epoch = int(subset_len / batch_size - 1)

print(f'Total steps per epoch: {steps_per_epoch}')

total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps)
# Warmup disabled; the 10% alternative was: warmup_steps = int(0.1 * total_steps)
warmup_steps = 0


logs_dir = './micro_att/tensorboard_logs/xsqa'
os.makedirs(logs_dir, exist_ok=True)


print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
count_cb = TokenCounterCallback(8_000_000_000)
acc_cb = PrintAccuracyCallback()

# Saves checkpoints locally and pushes them (and the final model) to a private HF Hub repo.
save_cb = ModelSaveCallback('./micro_att/xsqa', push_to_hub=True,
                            hub_model_id='ReactiveAI/xSQAT-mm', private_repo=True,
                            push_checkpoint_weights=True, final_commit_message='xSQA-mm Transformer v1.0.0')

train_dataset = AutoregressiveLMDataset(stories_dataset, tokenizer, max_seq_len=seq_len)
valid_dataset = AutoregressiveLMDataset(hf_valid_dataset, tokenizer, max_seq_len=seq_len)
trainer = AutoregressiveTrainer(decoder, device, dataset=train_dataset, validation_dataset=valid_dataset,
                         vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True,
                         dtype=torch.bfloat16, log_dir=logs_dir, use_moe_aux_loss=True, moe_aux_loss_scale=0.01)
optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.01)
scheduler = get_transformer_lr_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Monotonic clock: elapsed time is immune to system-clock adjustments during a long run.
t1 = time.perf_counter()
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
t2 = time.perf_counter()

# NOTE: this covers everything the trainer and callbacks do (e.g. per-epoch
# checkpoint uploads), not only forward/backward steps.
training_time = t2 - t1
time_per_step = training_time / total_steps

f'Total time: {(training_time / 60):.2f} min, Time per step/batch: {time_per_step:.4f} s'

cuda
Total steps per epoch: 8279
Start epoch: 0
Batch 100 / 8279 - loss: 2.716587543487549, last 100 batches mean loss: 3.3681
Total processed tokens: 5.56M
Batch 200 / 8279 - loss: 2.37518048286438, last 100 batches mean loss: 2.4536
Total processed tokens: 11.1M
Batch 300 / 8279 - loss: 2.0943000316619873, last 100 batches mean loss: 2.1978
Total processed tokens: 16.6M
Batch 400 / 8279 - loss: 2.019547700881958, last 100 batches mean loss: 2.0468
Total processed tokens: 22.1M
Batch 500 / 8279 - loss: 1.8654183149337769, last 100 batches mean loss: 1.9349
Total processed tokens: 27.6M
Batch 600 / 8279 - loss: 1.8388019800186157, last 100 batches mean loss: 1.8599
Total processed tokens: 33.1M
Batch 700 / 8279 - loss: 1.7626006603240967, last 100 batches mean loss: 1.7974
Total processed tokens: 38.6M
Batch 800 / 8279 - loss: 1.7357368469238281, last 100 batches mean loss: 1.7618
Total processed tokens: 44.1M
Batch 900 / 8279 - loss: 1.6965997219085693, last 100 batches mean loss: 1.7

model.safetensors:   0%|          | 0.00/34.1M [00:00<?, ?B/s]

Epoch 0 - mean loss: 1.5248
Start epoch: 1
Batch 100 / 8279 - loss: 1.3481640815734863, last 100 batches mean loss: 1.3635
Total processed tokens: 461M
Batch 200 / 8279 - loss: 1.3774563074111938, last 100 batches mean loss: 1.3589
Total processed tokens: 467M
Batch 300 / 8279 - loss: 1.3714088201522827, last 100 batches mean loss: 1.3589
Total processed tokens: 472M
Batch 400 / 8279 - loss: 1.341439127922058, last 100 batches mean loss: 1.3573
Total processed tokens: 478M
Batch 500 / 8279 - loss: 1.3576428890228271, last 100 batches mean loss: 1.3549
Total processed tokens: 483M
Batch 600 / 8279 - loss: 1.3610756397247314, last 100 batches mean loss: 1.3551
Total processed tokens: 489M
Batch 700 / 8279 - loss: 1.3875148296356201, last 100 batches mean loss: 1.3515
Total processed tokens: 494M
Batch 800 / 8279 - loss: 1.3424731492996216, last 100 batches mean loss: 1.3502
Total processed tokens: 500M
Batch 900 / 8279 - loss: 1.3359014987945557, last 100 batches mean loss: 1.3502
Total 

model.safetensors:   0%|          | 0.00/34.1M [00:00<?, ?B/s]

Epoch 1 - mean loss: 1.3263
Start epoch: 2
Batch 100 / 8279 - loss: 1.2677198648452759, last 100 batches mean loss: 1.2961
Total processed tokens: 917M
Batch 200 / 8279 - loss: 1.3005645275115967, last 100 batches mean loss: 1.2941
Total processed tokens: 922M
Batch 300 / 8279 - loss: 1.3129369020462036, last 100 batches mean loss: 1.2919
Total processed tokens: 928M
Batch 400 / 8279 - loss: 1.2966564893722534, last 100 batches mean loss: 1.2907
Total processed tokens: 933M
Batch 500 / 8279 - loss: 1.2532564401626587, last 100 batches mean loss: 1.2883
Total processed tokens: 939M
Batch 600 / 8279 - loss: 1.2883380651474, last 100 batches mean loss: 1.2936
Total processed tokens: 944M
Batch 700 / 8279 - loss: 1.2727413177490234, last 100 batches mean loss: 1.2918
Total processed tokens: 950M
Batch 800 / 8279 - loss: 1.3212602138519287, last 100 batches mean loss: 1.2899
Total processed tokens: 955M
Batch 900 / 8279 - loss: 1.3180983066558838, last 100 batches mean loss: 1.2959
Total pr

model.safetensors:   0%|          | 0.00/34.1M [00:00<?, ?B/s]

Epoch 2 - mean loss: 1.2775
Start epoch: 3
Batch 100 / 8279 - loss: 1.253278136253357, last 100 batches mean loss: 1.2600
Total processed tokens: 1.37B
Batch 200 / 8279 - loss: 1.2697815895080566, last 100 batches mean loss: 1.2568
Total processed tokens: 1.38B
Batch 300 / 8279 - loss: 1.2816534042358398, last 100 batches mean loss: 1.2566
Total processed tokens: 1.38B
Batch 400 / 8279 - loss: 1.23891282081604, last 100 batches mean loss: 1.2565
Total processed tokens: 1.39B
Batch 500 / 8279 - loss: 1.2512506246566772, last 100 batches mean loss: 1.2542
Total processed tokens: 1.39B
Batch 600 / 8279 - loss: 1.2673354148864746, last 100 batches mean loss: 1.2511
Total processed tokens: 1.4B
Batch 700 / 8279 - loss: 1.2597628831863403, last 100 batches mean loss: 1.2544
Total processed tokens: 1.41B
Batch 800 / 8279 - loss: 1.2705740928649902, last 100 batches mean loss: 1.2565
Total processed tokens: 1.41B
Batch 900 / 8279 - loss: 1.269407033920288, last 100 batches mean loss: 1.2560
To

model.safetensors:   0%|          | 0.00/34.1M [00:00<?, ?B/s]

Epoch 3 - mean loss: 1.2463
Start epoch: 4
Batch 100 / 8279 - loss: 1.219326376914978, last 100 batches mean loss: 1.2337
Total processed tokens: 1.83B
Batch 200 / 8279 - loss: 1.2576916217803955, last 100 batches mean loss: 1.2322
Total processed tokens: 1.83B
Batch 300 / 8279 - loss: 1.262859582901001, last 100 batches mean loss: 1.2308
Total processed tokens: 1.84B
Batch 400 / 8279 - loss: 1.2538000345230103, last 100 batches mean loss: 1.2332
Total processed tokens: 1.84B
Batch 500 / 8279 - loss: 1.2296335697174072, last 100 batches mean loss: 1.2330
Total processed tokens: 1.85B
Batch 600 / 8279 - loss: 1.2404842376708984, last 100 batches mean loss: 1.2296
Total processed tokens: 1.86B
Batch 700 / 8279 - loss: 1.2023570537567139, last 100 batches mean loss: 1.2293
Total processed tokens: 1.86B
Batch 800 / 8279 - loss: 1.171884298324585, last 100 batches mean loss: 1.2312
Total processed tokens: 1.87B
Batch 900 / 8279 - loss: 1.2056052684783936, last 100 batches mean loss: 1.2314


model.safetensors:   0%|          | 0.00/34.1M [00:00<?, ?B/s]

Epoch 4 - mean loss: 1.2289
Finished training! All losses:
[np.float64(1.5248214449571527), np.float64(1.3263190452891271), np.float64(1.2775069807462647), np.float64(1.2463469966742151), np.float64(1.2289010544905916)]
Total training tokens: 2.28B
Final model saved to ./micro_att/xsqa/final_model


model.safetensors:   0%|          | 0.00/34.1M [00:00<?, ?B/s]

Model uploaded to repo: ReactiveAI/xSQAT-mm


'Total time: 382.84, Time per step/batch: 0.5549'

# Computational efficiency on longer contexts

## Init models

In [None]:
torch.random.manual_seed(42)
np.random.seed(42)
random.seed(42)


def create_test_models(seq_len: int = 1024, embed_dim: int = 256, vocab_size: int = 10_000):
  """Build one ExperimentalAttentionTransformer per attention variant for benchmarking.

  Every variant shares the same base configuration:
  - 8 layers and 16 attention heads with att_groups=4;
  - a gated SiLU feed-forward of width 768 with dropout 0.1;
  - flash attention disabled.

  The variants differ only in these overrides:

    GQA  : att_type='gqa'
    MQA  : att_type='mqa'
    MHA  : att_type='mha'
    SQA  : att_type='sqa', att_num_query_groups=8
    sSQA : att_type='sqa', att_num_query_groups=8, att_groups=8
    xSQA : att_type='sqa', att_num_query_groups=4

  Args:
    seq_len: sequence length passed to every model.
    embed_dim: model embedding dimension (defaults to the previously hard-coded 256).
    vocab_size: vocabulary size (defaults to the previously hard-coded 10_000).

  Returns:
    dict mapping 'GQA', 'MQA', 'MHA', 'SQA', 'sSQA', 'xSQA' to the built models.
    Also prints a tuple of (name, params_count()) pairs.
  """
  base_config = {
      'num_layers': 8,
      'vocab_size': vocab_size,
      'embed_dim': embed_dim,
      'att_heads': 16,
      'att_groups': 4,
      'seq_len': seq_len,
      'use_flash_attention': False,
      'use_gated': True,
      'ff_dropout': 0.1,
      'ff_activation': 'silu',
      'ff_dim': 768,
  }

  # Per-variant overrides applied on top of base_config. Dicts keep insertion
  # order, so the models are built in the same fixed order every time.
  # This keeps any RNG-dependent initialisation reproducible after the seeding above.
  variant_overrides = {
      'GQA': {'att_type': 'gqa'},
      'MQA': {'att_type': 'mqa'},
      'MHA': {'att_type': 'mha'},
      'SQA': {'att_type': 'sqa', 'att_num_query_groups': 8},
      'sSQA': {'att_type': 'sqa', 'att_num_query_groups': 8, 'att_groups': 8},
      'xSQA': {'att_type': 'sqa', 'att_num_query_groups': 4},
  }

  models = {
      name: ExperimentalAttentionTransformer(**{**base_config, **overrides})
      for name, overrides in variant_overrides.items()
  }

  print(tuple((name, model.params_count()) for name, model in models.items()))

  return models

In [None]:
def _benchmark_forward(model: torch.nn.Module, inp: torch.Tensor, steps: int, warmup: int) -> tuple[float, float]:
  """Time `steps` forward passes of `model` on `inp` after `warmup` untimed passes.

  Returns (total_seconds, seconds_per_step).

  On CUDA the device is synchronized before the clock starts and again before it
  stops. Kernel launches are asynchronous, so without these barriers the timer
  would mostly measure launch overhead, and work still queued from one model
  would be billed to whichever model is timed next.
  """
  import time

  def _sync():
    if inp.device.type == 'cuda':
      torch.cuda.synchronize(inp.device)

  for _ in range(warmup):
    model(inp)
  _sync()
  start = time.perf_counter()  # monotonic, high-resolution (unlike datetime.now())
  for _ in range(steps):
    model(inp)
  _sync()
  total = time.perf_counter() - start
  return total, total / steps


def time_tests(models: 'dict[str, ExperimentalAttentionTransformer]', batch_size: int = 64, seq_len: int = 1024,
               steps: int = 50, warmup: int = 10) -> dict[str, tuple[float, float]]:
  """Benchmark forward-pass wall-clock time of each model on the same random batch.

  All models are moved to the device up front, so placement cost is never timed.
  They are then timed one after another, in the dict's insertion order, under
  bfloat16 autocast. Every model gets its own `warmup` untimed passes; previously
  only the first (GQA) model was warmed up, so the others were timed cold.

  NOTE(review): no `torch.no_grad()` / `.eval()` is applied, so models run in
  whatever mode they are in, with autograd enabled. Freshly constructed modules
  are in training mode (dropout active). Wrap the calls in
  `torch.inference_mode()` and call `.eval()` if a pure-inference figure is wanted.

  Args:
    models: mapping of label -> model, e.g. the dict from `create_test_models`.
    batch_size: number of sequences in the random input batch.
    seq_len: tokens per sequence in the random input batch.
    steps: timed forward passes per model (must be > 0).
    warmup: untimed forward passes per model before timing (must be >= 0).

  Returns:
    dict mapping each label to (total_seconds, seconds_per_step).

  Raises:
    ValueError: if `steps` <= 0 or `warmup` < 0.
  """
  if steps <= 0:
    raise ValueError(f'steps must be positive, got {steps}')
  if warmup < 0:
    raise ValueError(f'warmup must be non-negative, got {warmup}')

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  if device.type == 'cuda':
    torch.cuda.empty_cache()
  print(device)

  # Token ids in [1, 9999); the same batch is reused for every model.
  inp = torch.randint(1, 9999, (batch_size, seq_len), dtype=torch.int32).to(device)
  on_device = {name: model.to(device) for name, model in models.items()}

  results = {}
  with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
    for name, model in on_device.items():
      results[name] = _benchmark_forward(model, inp, steps, warmup)
  return results

#### 1024 Sequence / 128 batch size / 50 steps

In [None]:
time_tests(create_test_models(1024), 128, 1024)

(('GQA', 'Model params 11.2M'), ('MQA', 'Model params 11M'), ('MHA', 'Model params 12M'), ('SQA', 'Model params 10.7M'), ('sSQA', 'Model params 10.9M'), ('xSQA', 'Model params 10.4M'))
cuda
torch.int32


{'GQA': (5.8207361698150635, 0.2910368084907532),
 'MQA': (5.717133045196533, 0.28585665225982665),
 'MHA': (6.475995779037476, 0.3237997889518738),
 'SQA': (5.27108907699585, 0.2635544538497925),
 'sSQA': (5.335926055908203, 0.26679630279541017),
 'xSQA': (4.76692008972168, 0.23834600448608398)}

#### 4096 Sequence / 32 batch size / 50 steps

In [None]:
time_tests(create_test_models(4096), 32, 4096)

(('GQA', 'Model params 11.2M'), ('MQA', 'Model params 11M'), ('MHA', 'Model params 12M'), ('SQA', 'Model params 10.7M'), ('sSQA', 'Model params 10.9M'), ('xSQA', 'Model params 10.4M'))
cuda
torch.int32


{'GQA': (7.198396921157837, 0.35991984605789185),
 'MQA': (7.093986988067627, 0.3546993494033813),
 'MHA': (7.911501169204712, 0.3955750584602356),
 'SQA': (6.041421175003052, 0.3020710587501526),
 'sSQA': (6.046663045883179, 0.3023331522941589),
 'xSQA': (5.194203853607178, 0.2597101926803589)}

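#### 32768 Sequence / 4 batch size / 50 steps

**TODO:** the saved output of this cell is a `NameError`, because `time_tests` was not defined when it ran. Re-run this cell after the cell that defines `time_tests`.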

In [None]:
time_tests(create_test_models(32768), 4, 32768)

NameError: name 'time_tests' is not defined

#### 131072 Sequence / 1 batch size / 50 steps

In [None]:
time_tests(create_test_models(131072), 1, 131072)

(('GQA', 'Model params 11.2M'), ('MQA', 'Model params 11M'), ('MHA', 'Model params 12M'), ('SQA', 'Model params 10.7M'), ('sSQA', 'Model params 10.9M'), ('xSQA', 'Model params 10.4M'))
cuda
torch.int32


{'GQA': (60.77245020866394, 3.038622510433197),
 'MQA': (60.59097599983215, 3.0295487999916078),
 'MHA': (61.49839687347412, 3.074919843673706),
 'SQA': (36.01814913749695, 1.8009074568748473),
 'sSQA': (32.996859073638916, 1.649842953681946),
 'xSQA': (20.378159999847412, 1.0189079999923707)}