In [1]:
import numpy as np
from epsilon_transformers.process.transition_matrices import mess3
from epsilon_transformers.process.GHMM import TransitionMatrixGHMM

# Mess3 process parameters; `T` is whatever `mess3` returns for them
# (used below as the transition-matrix argument of the GHMM).
T = mess3(x=0.05, a=0.85)

# GHMM built on top of the given transition matrix
ghmm = TransitionMatrixGHMM(T)

# Generate `length` tokens step by step and materialise them as a list
length = 20
sequence = list(ghmm.yield_emissions(sequence_len=length))
sequence

[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 1, 0, 1, 0, 1, 1]

In [5]:
import torch
from epsilon_transformers.training.dataloader import get_dataloader_and_loss_lower_bound_from_process

process_params = {"name": "mess3", "x": 0.05, "a": 0.85}

# n_ctx is the context length (it includes BOS when bos=True)
dataloader, loss_lower_bound, d_vocab = get_dataloader_and_loss_lower_bound_from_process(
    process_params=process_params,
    n_ctx=8,
    bos=True,
    batches_per_epoch=1,
    batch_size=1,
    # Fall back to CPU so the cell also runs on machines without a GPU
    # (same device-selection rule as the model-building cells below).
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Fetch a single batch and look at the sequences X (input) and Y (next token)
X, Y = next(iter(dataloader))
print("X:", X.squeeze(0).tolist())
print("Y:", Y.squeeze(0).tolist())
print("loss_lower_bound:", loss_lower_bound.tolist())
print("d_vocab:", d_vocab)

Process initialized successfully!
X: [3, 2, 2, 0, 0, 0, 1, 0]
Y: [2, 2, 0, 0, 0, 1, 0, 0]
loss_lower_bound: [1.0986123085021973, 0.8577190041542053, 0.8133556842803955, 0.7981176376342773, 0.7947399616241455, 0.7937828302383423, 0.793542742729187, 0.7934796810150146]
d_vocab: 4


# Train

In [6]:
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from epsilon_transformers.process.transition_matrices import mess3, tom_quantum
from epsilon_transformers.process.GHMM import TransitionMatrixGHMM

# Fix the random seeds
# NOTE(review): whether the GHMM sampler draws from NumPy's global RNG is not
# visible in this notebook — confirm these seeds really make sampling reproducible.
np.random.seed(0)
torch.manual_seed(0)

# Processes
T_mess = mess3(x=0.05, a=0.85)          # 3 symbols
T_bloch = tom_quantum(alpha=2.5, beta=0.3)  # 4 symbols (Bloch Walk)
ghmm_mess = TransitionMatrixGHMM(T_mess)
ghmm_bloch = TransitionMatrixGHMM(T_bloch)

# Settings
seq_len = 8                    # Recommended short window
num_seqs = 4096                # Small dataset
vocab_mess, vocab_bloch = 3, 4
vocab_size = vocab_mess * vocab_bloch  # 12

def sample_sequence(ghmm, length):
    """Draw `length` emissions from `ghmm` and return them as a list.

    `ghmm` must expose `yield_emissions(sequence_len=...)` returning an iterable.
    """
    emissions = ghmm.yield_emissions(sequence_len=length)
    return [token for token in emissions]

def pair_to_token(mess_tok, bloch_tok):
    """Fuse a (Mess3, Bloch) token pair into a single joint token id.

    The joint id is ``mess_tok * vocab_bloch + bloch_tok``, where
    `vocab_bloch` is the notebook-level Bloch alphabet size.
    """
    mess_offset = mess_tok * vocab_bloch
    return mess_offset + bloch_tok

# Sample independent sequence pairs from the two processes and fuse the tokens
# position by position (Cartesian product of the two alphabets at each time step).
X_list = []
Y_list = []
for _ in range(num_seqs):
    s_m = sample_sequence(ghmm_mess, seq_len)
    s_b = sample_sequence(ghmm_bloch, seq_len)
    s_c = [pair_to_token(m, b) for m, b in zip(s_m, s_b)]  # 0..11

    x = s_c[:-1]  # input
    y = s_c[1:]   # target (next token)
    X_list.append(x)
    Y_list.append(y)

X = torch.tensor(X_list, dtype=torch.long)  # [N, seq_len-1]
Y = torch.tensor(Y_list, dtype=torch.long)  # [N, seq_len-1]

# Train/validation split: random permutation, first 90% train, the rest validation
N = X.size(0)
idx = torch.randperm(N)
split = int(0.9 * N)
train_idx, val_idx = idx[:split], idx[split:]

train_ds = TensorDataset(X[train_idx], Y[train_idx])
val_ds = TensorDataset(X[val_idx], Y[val_idx])

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=128, shuffle=False)

# Cell output: joint vocabulary size and the input/target tensor shapes
vocab_size, X.shape, Y.shape

(12, torch.Size([4096, 7]), torch.Size([4096, 7]))

In [11]:
# Peek at one training batch; the loader yields an (inputs, targets) pair of tensors.
next(iter(train_loader))

[tensor([[ 4,  9,  9,  8, 11,  8,  4],
         [ 1,  3,  2,  1,  2,  2,  3],
         [ 2,  6,  6,  7,  0,  7,  7],
         [ 3,  3,  0,  5,  3,  0,  9],
         [ 0,  3,  1,  3,  0,  1,  3],
         [ 0,  4,  0,  2,  7,  7,  3],
         [ 1,  2,  3,  3,  3, 11,  9],
         [10,  9,  7, 10,  9, 10,  5],
         [ 7,  7,  4,  7,  5, 11,  7],
         [ 8,  7,  7,  6, 11, 11, 11],
         [ 2,  1,  3,  7,  3,  2,  2],
         [ 0,  8,  4,  8,  0,  0,  3],
         [ 6,  5,  8, 11,  2,  9,  8],
         [11,  3, 11,  8,  8,  3,  0],
         [ 8, 10,  9,  0,  2,  3,  7],
         [ 0,  4,  5,  7,  7,  6,  4],
         [ 3,  3,  0,  2,  2,  3,  0],
         [ 2,  3,  2,  2,  4,  2,  8],
         [ 5,  6,  5,  2,  4,  5,  4],
         [ 5,  4,  6,  4,  2,  0,  1],
         [10,  1,  2,  6,  1, 11,  3],
         [ 5, 10,  7,  0,  4, 10,  4],
         [ 7,  4,  5,  7,  5,  6, 10],
         [ 7,  6,  5,  5,  4,  5,  5],
         [ 6,  1,  2,  1,  4, 11,  2],
         [11, 11,  9,  5,

In [14]:
def extract_transformer_kwargs(cfg):
    """Build the transformer-architecture kwargs from an experiment config dict.

    Swept hyper-parameters (`n_layers`, `n_heads`, `d_head`) are read from
    ``cfg["sweep_config"]["model_config"]``; each is a list and only its first
    entry is used. `d_model` is derived as ``d_head * n_heads`` and `d_mlp`
    as ``4 * d_model``. The remaining options (`act_fn`, `normalization_type`,
    `attn_only`, `seed`) are copied verbatim from ``cfg["model_config"]``.

    Returns:
        dict with keys n_layers, n_heads, d_head, d_model, d_mlp, act_fn,
        normalization_type, attn_only, seed (in that order).

    Raises:
        KeyError / IndexError if a required entry is missing from `cfg`.
    """
    swept = cfg["sweep_config"]["model_config"]
    fixed = cfg["model_config"]

    n_layers = swept["n_layers"][0]
    n_heads = swept["n_heads"][0]
    d_head = swept["d_head"][0]
    d_model = d_head * n_heads

    kwargs = dict(
        n_layers=n_layers,
        n_heads=n_heads,
        d_head=d_head,
        d_model=d_model,
        d_mlp=4 * d_model,
    )
    for key in ("act_fn", "normalization_type", "attn_only", "seed"):
        kwargs[key] = fixed[key]
    return kwargs

In [15]:
import yaml, torch
from transformer_lens import HookedTransformer, HookedTransformerConfig

# NOTE(review): absolute, machine-specific path — the notebook only runs where this
# file exists; consider a path relative to the repository root.
cfg_path = "/workspace-SR008.nfs2/nachevsky/simplex/epsilon-transformers/configs/experiment_config_transformer_mess3_bloch_hw.yaml"
with open(cfg_path, "r") as f:
    cfg = yaml.safe_load(f)

device = "cuda" if torch.cuda.is_available() else "cpu"
# One batch is only needed to read the sequence length the model must support.
X1, Y1 = next(iter(train_loader))

model = HookedTransformer(
    HookedTransformerConfig(
        n_ctx=X1.shape[1],
        # Size the embedding from the known joint vocabulary instead of the largest
        # token id that happens to appear in one random batch: a batch that lacks
        # the top id would otherwise give a too-small vocabulary and out-of-range
        # indices later in training.
        d_vocab=vocab_size,
        device=device,
        dtype=getattr(torch, cfg["model_config"]["dtype"]),
        **extract_transformer_kwargs(cfg)
    )
).to(device)

Moving model to device:  cuda


In [21]:
import os

# Base (train, val) batch sizes and the multipliers applied to them below.
batch_sizes = (128, 256)
amplifiers = (20, 12)


loader_kwargs = dict(
    # os.cpu_count() may return None when the CPU count cannot be determined;
    # treat that as a single CPU instead of failing on `None // 2`.
    # At least 4 workers are always requested.
    num_workers=max(4, (os.cpu_count() or 1)//2),
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)


# Effective batch size per split = base batch size * amplifier.
train_batch_size = batch_sizes[0] * amplifiers[0]
val_batch_size = batch_sizes[1] * amplifiers[1]

# Rebuild both loaders over the existing datasets with the multi-worker settings;
# only the training loader shuffles.
train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, batch_size=val_batch_size, shuffle=False, **loader_kwargs)

In [16]:
import torch.nn.functional as F

# The configured learning rate (first sweep entry) is multiplied by the same
# amplifier that scales the training batch size.
base_learning_rate = cfg["sweep_config"]["train_config"]["learning_rate"][0]
optimizer = torch.optim.Adam(model.parameters(), lr=base_learning_rate * amplifiers[0])

# Halve the LR after 10 validation checks without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=10,
    threshold=1e-6,
)

@torch.no_grad()
def val_loss_ce():
    model.eval()
    tot, denom = 0.0, 0
    for X, Y in val_loader:
        X, Y = X.to(device), Y.to(device).long()
        logits = model(X)  # [B, L, V]
        B, L, V = logits.shape
        loss = F.cross_entropy(logits.reshape(-1, V), Y.reshape(-1), reduction="sum")
        tot += loss.item(); denom += B*L
    return tot / denom

In [18]:
from tqdm.auto import tqdm

EPOCHS = cfg["train_config"]["n_epochs"]
best = float("inf")
pbar = tqdm(range(EPOCHS), desc="train")
WARMUP_STEPS = 1000
global_step = 0

# Target learning rate for the warmup. `optimizer.defaults` holds the value the
# optimizer was constructed with and is not modified by the warmup or by the
# plateau scheduler, so re-running this cell still ramps to the right value.
base_lr = optimizer.defaults["lr"]
# Start from lr = 0; the linear warmup below ramps it up to `base_lr`.
for g in optimizer.param_groups: g['lr'] = 0


for ep in pbar:
    model.train()
    running_loss, steps = 0.0, 0
    for X, Y in train_loader:
        X, Y = X.to(device), Y.to(device).long()

        # Linear warmup over the first WARMUP_STEPS optimizer steps. Without it the
        # learning rate would stay at 0 forever and the model would never update.
        global_step += 1
        if global_step <= WARMUP_STEPS:
            wlr = base_lr * (global_step / WARMUP_STEPS)
            for g in optimizer.param_groups: g['lr'] = wlr

        optimizer.zero_grad(set_to_none=True)
        logits = model(X)
        B, L, V = logits.shape
        loss = F.cross_entropy(logits.reshape(-1, V), Y.reshape(-1), reduction="mean")
        loss.backward(); optimizer.step()
        running_loss += loss.item(); steps += 1

    train_loss = running_loss / max(steps, 1)
    v = val_loss_ce()
    # Let the plateau scheduler act only once warmup is over, so it neither reacts
    # to the still-ramping learning rate nor gets overwritten by the warmup.
    if global_step > WARMUP_STEPS:
        scheduler.step(v)

    # Keep the weights of the best validation loss seen so far.
    if v < best:
        best = v
        torch.save(model.state_dict(), "best_model.pt")

    pbar.set_postfix(train_loss=f"{train_loss:.4f}", val_loss=f"{v:.4f}", lr=f"{optimizer.param_groups[0]['lr']:.2e}")
        

train:   4%| | 861/20000 [04:52<1:48:28,  2.94it/s, lr=1.22e-08, train_loss=2.1724, val_loss=2.2464]


KeyboardInterrupt: 

In [20]:
import os

# Base (train, val) batch sizes and the multipliers applied to them below.
batch_sizes = (128, 256)
amplifiers = (20, 12)


loader_kwargs = dict(
    # os.cpu_count() may return None when the CPU count cannot be determined;
    # treat that as a single CPU instead of failing on `None // 2`.
    # At least 4 workers are always requested.
    num_workers=max(4, (os.cpu_count() or 1)//2),
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)


# Effective batch size per split = base batch size * amplifier.
train_batch_size = batch_sizes[0] * amplifiers[0]
val_batch_size = batch_sizes[1] * amplifiers[1]

# Rebuild both loaders over the existing datasets with the multi-worker settings;
# only the training loader shuffles.
train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, batch_size=val_batch_size, shuffle=False, **loader_kwargs)

# Train

In [1]:
import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, HookedTransformerConfig

from epsilon_transformers.training.generate_data import (
    load_config,
    generate_and_save_data,
    load_process_data,   
    get_process_string, 
)
from extractors import extract_transformer_kwargs, get_caches
from tokenizer import TokenizerMessBloch


# Shared DataLoader settings used by get_dataloaders().
loader_kwargs = dict(
    # os.cpu_count() may return None when the CPU count cannot be determined;
    # fall back to 1 so max() does not compare an int with None.
    # At least 24 workers are always requested.
    num_workers=max(24, os.cpu_count() or 1),
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=40,
)


def get_dataloaders(dataset, batch_sizes=(128, 256), amplifiers=(20, 12), split_ratio=0.9):
    """Split `dataset` into train/validation next-token loaders.

    The first ``int(len(dataset) * split_ratio)`` rows form the training part and
    the remainder the validation part (no shuffling before the split). Each part
    becomes a TensorDataset of (rows[:, :-1], rows[:, 1:]), i.e. inputs and the
    same rows shifted by one position.

    Args:
        dataset: 2-D sliceable tensor of token rows.
        batch_sizes: base (train, val) batch sizes.
        amplifiers: (train, val) multipliers applied to the base batch sizes.
        split_ratio: fraction of rows assigned to the training part.

    Returns:
        (train_loader, val_loader); only the training loader shuffles. Both use
        the module-level `loader_kwargs`.
    """
    n_train = int(len(dataset) * split_ratio)
    train_rows = dataset[:n_train]
    val_rows = dataset[n_train:]

    def _next_token_pairs(rows):
        # inputs = all but the last position, targets = all but the first
        return TensorDataset(rows[:, :-1], rows[:, 1:])

    train_pairs = _next_token_pairs(train_rows)
    val_pairs = _next_token_pairs(val_rows)

    train_batch = batch_sizes[0] * amplifiers[0]
    val_batch = batch_sizes[1] * amplifiers[1]

    train_loader = DataLoader(train_pairs, batch_size=train_batch, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_pairs, batch_size=val_batch, shuffle=False, **loader_kwargs)
    return train_loader, val_loader


def get_dataset_from_caches(cfg):
    """Build the joint token dataset from the cached Bloch and Mess3 inputs.

    Reads the "transformer_inputs" entry of both caches returned by
    `get_caches(cfg)`, tiles the Mess3 rows along dim 0 until there are at least
    as many as Bloch rows, truncates to exactly the Bloch row count, and returns
    `TokenizerMessBloch().encode(mess3_rows, bloch_rows)`.
    """
    bloch_cache, mess3_cache = get_caches(cfg)
    bloch_inputs = torch.tensor(bloch_cache["transformer_inputs"], dtype=torch.long)
    mess3_inputs = torch.tensor(mess3_cache["transformer_inputs"], dtype=torch.long)

    n_bloch_rows = bloch_inputs.shape[0]
    whole_repeats = int(n_bloch_rows / mess3_inputs.shape[0])
    # One extra copy covers the remainder; the slice trims the overshoot.
    tiled_mess3 = torch.concat([mess3_inputs] * (whole_repeats + 1))
    mess3_aligned = tiled_mess3[:n_bloch_rows]

    return TokenizerMessBloch().encode(mess3_aligned, bloch_inputs)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
@torch.no_grad()
def val_loss_ce(model, val_loader):
    """Return the mean per-token cross-entropy of `model` over `val_loader`.

    Args:
        model: module mapping a batch of token ids to logits of shape [B, L, V].
        val_loader: iterable of (X, Y) batches; Y holds the target token ids.

    Returns:
        float: summed cross-entropy over all tokens divided by the token count.

    Raises:
        ZeroDivisionError: if `val_loader` yields no batches.

    The model is switched to eval mode and left in it. Batches are moved to the
    device of the model's own parameters, so the function does not depend on a
    notebook-global `device` defined in another cell.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    # A model without parameters has no device of its own; leave batches where they are.
    target_device = first_param.device if first_param is not None else None

    tot, denom = 0.0, 0
    for X, Y in val_loader:
        if target_device is not None:
            X, Y = X.to(target_device), Y.to(target_device)
        Y = Y.long()
        logits = model(X)  # [B, L, V]
        B, L, V = logits.shape
        loss = F.cross_entropy(logits.reshape(-1, V), Y.reshape(-1), reduction="sum")
        tot += loss.item(); denom += B*L
    return tot / denom


In [3]:
# Batch-size / learning-rate multipliers (train, val).
amplifiers = (20, 12)
CFG_PATH = "configs/experiment_config_transformer_mess3_bloch_hw.yaml"
cfg = load_config(CFG_PATH)

device = "cuda" if torch.cuda.is_available() else "cpu"
dataset = get_dataset_from_caches(cfg)

# Context is one position longer than the configured n_ctx; the joint
# vocabulary size is fixed at 12.
transformer_cfg = HookedTransformerConfig(
    n_ctx=cfg["model_config"]["n_ctx"] + 1,
    d_vocab=12,
    device=device,
    dtype=getattr(torch, cfg["model_config"]["dtype"]),
    **extract_transformer_kwargs(cfg),
)
model = HookedTransformer(transformer_cfg).to(device)




Moving model to device:  cuda


In [None]:
# Training loop with a manual linear LR warmup followed by ReduceLROnPlateau.
best = float("inf")
pbar = tqdm(range(cfg["train_config"]["n_epochs"]), desc="train")
WARMUP_STEPS = 30
global_step = 0


# (train, val) multipliers; overrides the value set in the model-building cell.
amplifiers = (200, 300)
train_loader, val_loader = get_dataloaders(dataset, amplifiers=amplifiers)


# Learning rate = configured rate (first sweep entry) * training amplifier.
optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=cfg["sweep_config"]["train_config"]["learning_rate"][0]*amplifiers[0]
    )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=10, threshold=1e-6
)

# Remember the target LR, then start from 0 so the warmup can ramp up to it.
base_lr = optimizer.param_groups[0]['lr'] 
for g in optimizer.param_groups: g['lr'] = 0


for ep in pbar:
    model.train()
    running_loss, steps = 0.0, 0
    for X, Y in train_loader:
        X, Y = X.to(device), Y.to(device).long()
        optimizer.zero_grad(set_to_none=True)
        logits = model(X)
        B, L, V = logits.shape
        loss = F.cross_entropy(logits.reshape(-1, V), Y.reshape(-1), reduction="mean")
        loss.backward(); optimizer.step()
        running_loss += loss.item(); steps += 1

        # Linear warmup: the LR set here applies from the NEXT optimizer step on,
        # so the very first step of training runs with lr = 0.
        global_step += 1
        if global_step <= WARMUP_STEPS:
            wlr = base_lr * (global_step / WARMUP_STEPS)
            for g in optimizer.param_groups: g['lr'] = wlr    
    
    train_loss = running_loss / max(steps, 1)
    val_loss = val_loss_ce(model, val_loader)
    # The plateau scheduler only takes over once the warmup has finished.
    if global_step > WARMUP_STEPS:
        scheduler.step(val_loss)

    # Checkpointing of the best model is currently disabled (`best` is unused).
    # if val_loss < best:
    #     best = val_loss
    #     torch.save(model.state_dict(), "best_model.pt")

    pbar.set_postfix(train_loss=f"{train_loss:.4f}", val_loss=f"{val_loss:.4f}", lr=f"{optimizer.param_groups[0]['lr']:.2e}")


In [None]:
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")


amplifiers = (1, 1)
train_loader, val_loader = get_dataloaders(dataset, amplifiers=amplifiers)
base_lr = cfg["sweep_config"]["train_config"]["learning_rate"][0]*amplifiers[0]

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=base_lr,
    fused=True if torch.cuda.is_available() else False,  # speeds up the optimizer step
    betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
)

# NOTE(review): loss scaling is normally needed for float16 only; with the bfloat16
# autocast used below it is probably unnecessary — confirm before removing.
scaler = GradScaler(enabled=True)
grad_accum_steps = 4


WARMUP_STEPS = 30
def lr_lambda(step):
    """LR multiplier: linear ramp up to 1.0 over WARMUP_STEPS scheduler steps, then 1.0."""
    if step < WARMUP_STEPS:
        return float(step + 1) / float(WARMUP_STEPS)   # linear warmup 0 -> 1
    return 1.0

warmup = LambdaLR(optimizer, lr_lambda)
plateau = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=10, threshold=1e-6)


pbar = tqdm(range(cfg["train_config"]["n_epochs"]), desc="train")
best = float("inf")
global_step = 0

for ep in pbar:
    model.train()
    running_loss, steps = 0.0, 0
    optimizer.zero_grad(set_to_none=True)
    num_batches = len(train_loader)

    for it, (X, Y) in enumerate(train_loader):
        # non_blocking only helps together with pin_memory=True in the DataLoader
        X = X.to(device, non_blocking=True)
        Y = Y.to(device, non_blocking=True).long()

        with autocast(dtype=torch.bfloat16, enabled=True):  # or torch.float16 where supported
            logits = model(X)  # expected shape (B, L, V)
            B, L, V = logits.shape
            loss = F.cross_entropy(logits.reshape(-1, V), Y.reshape(-1), reduction="mean")

        scaler.scale(loss / grad_accum_steps).backward()

        # Step every `grad_accum_steps` micro-batches AND on the last batch of the
        # epoch. Otherwise the gradients of a trailing, incomplete accumulation group
        # are thrown away by the zero_grad() at the start of the next epoch.
        is_last_batch = (it + 1) == num_batches
        if (it + 1) % grad_accum_steps == 0 or is_last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1
            # Advance the warmup only while it is still ramping. LambdaLR recomputes
            # the LR from the initial base LR on every step(), so calling it after
            # warmup would silently undo every reduction made by ReduceLROnPlateau.
            if global_step < WARMUP_STEPS:
                warmup.step()

        running_loss += loss.item()
        steps += 1

    train_loss = running_loss / max(steps, 1)

    # === Validation without gradients and without autocast (or with autocast for speed) ===
    model.eval()
    with torch.no_grad():
        val_loss = val_loss_ce(model, val_loader)  # make sure there is no .item() inside its loop

    # After warmup, hand LR control to the plateau scheduler driven by validation loss
    if global_step >= WARMUP_STEPS:
        plateau.step(val_loss)

    pbar.set_postfix(
        train_loss=f"{train_loss:.4f}",
        val_loss=f"{val_loss:.4f}",
        lr=f"{optimizer.param_groups[0]['lr']:.2e}"
    )