## Main Model

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # must be BEFORE torch/TF import
import warnings
warnings.filterwarnings("ignore")
import torch.optim as optim
import torch.nn.functional as F
from bytelatent.model.blt import ByteLatentTransformerArgs, ByteLatentTransformer
from utils.train_utils import *
from tqdm import tqdm
import time
import numpy as np
from pathlib import Path
import wandb
from bytelatent.tokenizers.constants import PAD_ID
from utils.eval_utils import evaluation

In [None]:
# CUDA_VISIBLE_DEVICES is set to "2" in the import cell, so index 0 here means
# "the first visible GPU", i.e. the physical device with index 2.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    device = torch.device('cuda')
else:
    # Fall back to CPU instead of failing inside set_device() when no GPU is present.
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
## Training Args
# Every tunable value for the run lives in this cell; later cells read these names.

# Tokenizer / data
vocab_size = 4096
quant_range = 15
batch_size = 256
seq_len = 96
dataset_name = 'ETTm1'
features = 'M'

# Optimisation
learning_rate = 5e-4
weight_decay = 1e-2
epochs = 500  # upper bound on epochs; early stopping (see `patience`) can end the run sooner
grad_accumulation_steps = 1
clip_grad = 1.0  # max gradient norm; the training loop skips clipping when this is <= 0
seed = 42

# Learning-rate schedule (passed to get_lr in the training loop)
warmup_steps = 0
min_lr_factor = 0.1  # min_lr = learning_rate * min_lr_factor
decay_lr = True

# Runtime / bookkeeping
# NOTE: this name shadows the built-in compile(); it is kept because the model
# cell reads it. It was previously assigned twice in this cell with the same value.
compile = True
output_dir = "output"
save_every = 10  # write a periodic checkpoint every `save_every` epochs (0 disables)
# eval_every = 100  # currently unused
patience = 6   # Early stopping patience

In [None]:
# Make sure the checkpoint directory (and any missing parents) exists;
# an already existing directory is not an error.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Arguments shared by the train / validation / test loaders; only `flag` differs.
loader_kwargs = dict(
    dataset_name=dataset_name,
    features=features,
    seq_len=seq_len,
    label_len=0,
    pred_len=96,
    batch_size=batch_size,
    pretrain=True,
)

train_dataset, train_loader = build_dataloader(flag='train', **loader_kwargs)
validate_dataset, validate_loader = build_dataloader(flag='val', **loader_kwargs)
test_dataset, test_loader = build_dataloader(flag='test', **loader_kwargs)

print(f"Dataset: {dataset_name}, Features: {features}, Batch Size: {batch_size}, Seq Len: {seq_len}")

# Tokenizer used to turn the raw series into model inputs
tokenizer = build_tokenizer(
    quant_range=quant_range,
    vocab_size=vocab_size,
    context_length=seq_len,
    prediction_length=96,
)

In [None]:
## Set model args
# NOTE(review): the per-field remarks describe intent as far as it can be read
# from this notebook; confirm exact semantics against ByteLatentTransformerArgs.
model_args = ByteLatentTransformerArgs(
    seed=seed,                             # reuse the notebook-level seed instead of a second hard-coded 42
    vocab_size=vocab_size,                 # same vocab_size that is passed to build_tokenizer
    max_length=seq_len,                    # all sequence-length limits are tied to seq_len
    max_seqlen=seq_len,
    max_encoder_seq_length=seq_len,
    local_attention_window_len=seq_len,    # window length equals the sequence length

    # Model widths
    dim_global=32,
    dim_local_encoder=16,
    dim_local_decoder=16,

    # Depths
    n_layers_global=4,
    n_layers_local_encoder=4,
    n_layers_local_decoder=4,

    # Attention heads
    n_heads_global=4,
    n_heads_local_encoder=2,
    n_heads_local_decoder=2,

    # Patching
    patch_size=8,
    patch_in_forward=False,                # presumably: no patching inside forward(); patch_lengths are passed in by the caller
    patching_batch_size=256,
    patching_device="cuda",                # NOTE(review): an earlier comment said "use CPU for patching" but the value is "cuda" — confirm
    patching_mode="entropy",
    patching_threshold=3.0,
    max_patch_length=16,
    monotonicity=True,
    pad_to_max_length=True,

    # Cross-attention between the local and global streams
    cross_attn_encoder=True,
    cross_attn_decoder=True,
    cross_attn_k=2,
    cross_attn_nheads=2,
    cross_attn_all_layers_encoder=True,
    cross_attn_all_layers_decoder=True,
    cross_attn_use_flex_attention=False,
    cross_attn_init_by_pooling=True,

    # Hash-based n-gram embeddings in the encoder
    encoder_hash_byte_group_size=[6,7,8],
    encoder_hash_byte_group_vocab=32,
    encoder_hash_byte_group_nb_functions=1,
    encoder_enable_byte_ngrams=False,

    non_linearity="swiglu",
    use_rope=True,
    attn_impl="sdpa",
    attn_bias_type="causal",

    dropout=0.0,
    layer_ckpt="none",
    init_use_gaussian=True,
    init_use_depth="current",
    alpha_depth="disabled",
    log_patch_lengths=True,

    downsampling_by_pooling="max",
    use_local_encoder_transformer=True,
    share_encoder_decoder_emb=True
)

In [None]:
# Seed every RNG *before* the model is built so that weight initialisation is
# reproducible. Seeding after construction (as this cell did before) leaves the
# initial weights dependent on whatever the RNG state happened to be.
torch.manual_seed(model_args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(model_args.seed)
torch.set_float32_matmul_precision('high')

model = ByteLatentTransformer(model_args)
model = model.to(device)
if compile:  # notebook-level flag from the config cell (shadows the built-in compile())
    model = torch.compile(model)


def count_parameters(model):
    """Return the number of trainable parameters (requires_grad=True) in `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model_param_count = count_parameters(model)
print(f"Number of parameters in model: {model_param_count / 1e6:.2f}M")

# Patch lengths are built once here and passed to every forward call.
# Earlier experiment kept for reference: torch.full((batch_size, 8), 12).to('cuda')
patch_lengths = create_static_patch_lengths(batch_size=batch_size, seq_len=seq_len)

# Use the values from the config cell rather than repeating the literals, so
# changing learning_rate / weight_decay there actually takes effect here.
optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    betas=(0.9, 0.95)
)
optimizer.zero_grad(set_to_none=True)

# NOTE(review): `dtype` only controls whether the GradScaler is enabled and what
# is printed; no autocast context is opened in this notebook, so confirm whether
# mixed precision is actually intended.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
scaler = torch.amp.GradScaler(enabled=(dtype == 'float16'))
print(f"Using precision: {dtype}")

In [None]:
"""
Training function with early stopping, periodic evaluation, and WandB logging
"""
early_stopping = EarlyStopping(patience=patience, min_delta=1e-6)
# NOTE(review): `logger` is not used again in this cell (print_summary below is
# commented out). It is kept because constructing it may have side effects such
# as starting the WandB run — confirm in utils.train_utils.
logger = TrainingLogger(output_dir, dataset_name, enable_wandb=ENABLE_WANDB)

num_batches = len(train_loader)
total_steps = epochs * num_batches
min_lr = learning_rate * min_lr_factor
best_val_loss = float('inf')
training_start_time = time.time()  # reference point for the total training duration


def _build_checkpoint(epoch, val_loss, train_loss):
    """Snapshot model, optimizer and scaler state plus the losses of `epoch`.

    Reads `model`, `optimizer`, `scaler` and `model_args` from the notebook
    namespace at call time and returns a dict suitable for torch.save().
    """
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler': scaler.state_dict() if scaler else None,
        'val_loss': val_loss,
        'train_loss': train_loss,
        'model_args': model_args.__dict__
    }


print(f"\n🚀 Starting training with early stopping...")
print(f"📝 Configuration:")
print(f"   Max epochs: {epochs}")
print(f"   Early stopping patience: {patience}")
print(f"   Save every: {save_every} epochs")
print(f"   WandB logging: {'Enabled' if ENABLE_WANDB else 'Disabled'}")

# Training loop
for epoch in range(epochs):
    # Training phase
    model.train()
    t1 = time.time()
    epoch_loss = 0
    current_lr = 0
    batch_losses = []

    progress_bar = tqdm(
        enumerate(train_loader),
        total=len(train_loader),
        desc=f"🏃 Epoch {epoch+1}/{epochs}",
        position=0,
        leave=True
    )

    for i, (batch_x, batch_y, _, _) in progress_bar:
        iteration = epoch * num_batches + i  # global step index across epochs
        x = batch_x.float().squeeze(-1)
        y = batch_y.float().squeeze(-1)

        # The schedule is evaluated per step and written into the optimizer by hand
        lr = get_lr(iteration, total_steps, warmup_steps, learning_rate, min_lr, decay_lr)
        current_lr = lr

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        total_loss = 0
        optimizer.zero_grad(set_to_none=True)

        # Gradient accumulation loop
        # NOTE(review): every micro-step re-uses the same (x, y) batch, so with
        # grad_accumulation_steps > 1 this averages identical gradients instead of
        # accumulating over new data.
        for micro_step in range(grad_accumulation_steps):
            token_ids, attention_mask, tokenizer_state = tokenizer.context_input_transform(x)
            # The target tokens are not part of the MSE loss below; they are only
            # needed by the commented-out cross-entropy variant and the debug cell.
            target_token_ids, target_attention_mask = tokenizer.label_input_transform(y, tokenizer_state)

            # Forward pass
            # NOTE(review): patch_lengths was built for a full batch_size; confirm the
            # loader never yields a smaller final batch.
            logits = model(token_ids.to(device), patch_lengths)
            # MSE loss: the model output is compared directly with the raw target y
            loss = F.mse_loss(logits, y.to(device), reduction='mean')

            # # Calculate loss
            # loss = F.cross_entropy(
            #     logits.reshape(-1, logits.size(-1)),
            #     target_token_ids.reshape(-1).to(device),
            #     ignore_index=PAD_ID
            # )
            loss = loss / grad_accumulation_steps

            # Backward pass
            scaler.scale(loss).backward()
            total_loss += loss.item() * grad_accumulation_steps

        # Gradient clipping (gradients must be unscaled before their norm is measured)
        if clip_grad > 0:
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

            # Log gradient norm to wandb periodically
            if ENABLE_WANDB and i % 100 == 0:
                wandb.log({
                    'train/grad_norm': grad_norm,
                    'train/step': iteration,
                    'train/batch_loss': total_loss
                })
        else:
            grad_norm = 0

        # Update weights
        scaler.step(optimizer)
        scaler.update()

        # Update metrics
        epoch_loss += total_loss
        batch_losses.append(total_loss)
        avg_epoch_loss = epoch_loss / (i + 1)

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f"{total_loss:.4f}",
            'avg_loss': f"{avg_epoch_loss:.4f}",
            'lr': f"{lr:.6f}",
            'patience': f"{early_stopping.counter}/{patience}"
        })

    # Calculate training metrics
    train_time = time.time() - t1
    train_avg_loss = epoch_loss / len(train_loader)
    train_std_loss = np.std(batch_losses) if len(batch_losses) > 1 else 0

    # Validation phase
    print(f"\n🔍 Running validation for epoch {epoch+1}...")
    t1 = time.time()
    model.eval()
    val_loss = validate(model, validate_loader, tokenizer, patch_lengths, device,
                        desc=f"Epoch {epoch+1} Validation")
    val_time = time.time() - t1

    # Print epoch results
    print(f"\n📊 Epoch {epoch+1}/{epochs} Results:")
    print(f"   Training Loss: {train_avg_loss:.6f} ± {train_std_loss:.6f} (Time: {train_time:.2f}s)")
    print(f"   Validation Loss: {val_loss:.6f} (Time: {val_time:.2f}s)")
    print(f"   Learning Rate: {current_lr:.6f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        checkpoint = _build_checkpoint(epoch, val_loss, train_avg_loss)
        torch.save(checkpoint, os.path.join(output_dir, f'best_model_{dataset_name}_{features}_{seq_len}.pth'))
        print(f"   ✅ New best model saved! (Val Loss: {best_val_loss:.6f})")

    # Save periodic checkpoint
    if save_every > 0 and (epoch + 1) % save_every == 0:
        checkpoint_path = os.path.join(output_dir, f'checkpoint_{seq_len}_epoch_{epoch+1}.pth')
        torch.save(_build_checkpoint(epoch, val_loss, train_avg_loss), checkpoint_path)
        print(f"   💾 Checkpoint saved at epoch {epoch+1}")

    # Early stopping check
    if early_stopping(val_loss, model):
        print(f"\n🛑 Early stopping triggered after {epoch+1} epochs!")
        print(f"   No improvement for {patience} consecutive epochs")
        print(f"   Best validation loss: {early_stopping.best_loss:.6f}")

        if ENABLE_WANDB:
            wandb.run.summary['early_stopped'] = True
            wandb.run.summary['early_stop_epoch'] = epoch + 1
            wandb.run.summary['early_stop_patience'] = patience
        break

# Training completed
print(f"\n🎉 Training completed!")

# Print summary
# logger.print_summary()

# Final-state snapshot (currently only built, not written: the save below is commented out)
final_checkpoint = _build_checkpoint(epoch, val_loss, train_avg_loss)
final_checkpoint.update({
    'early_stopped': early_stopping.counter >= patience,
    'final_metrics': {
        'best_val_loss': best_val_loss,
        'total_epochs': epoch + 1,
        # Elapsed wall-clock seconds for the whole run; this used to store the raw
        # time.time() timestamp, which is not a duration.
        'total_training_time': time.time() - training_start_time
    }
})
# torch.save(final_checkpoint, os.path.join(output_dir, f'final_model_{dataset_name}_{features}_{seq_len}.pth'))

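Results so far (presumably test MSE — the metric is not stated in the original notes):

- 0.334 — all embeddings passed to the head, with small dimensions.
- 0.339 — all patches of length 8.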

In [None]:
# Debug aid: shapes of the tensors left over from the last training step
tuple(t.shape for t in (token_ids, y, target_token_ids, target_attention_mask))

In [None]:
# Load and evaluate best model
print("\nEvaluating best model on test set...")
best_model_path = os.path.join(output_dir, f"best_model_{dataset_name}_{features}_{seq_len}.pth")
# map_location puts the tensors on the device used in this session, whatever
# device they were saved from.
# weights_only=False: the checkpoint also stores the non-tensor `model_args` dict,
# and the default for this flag differs between PyTorch releases. This unpickles
# arbitrary objects, so only load checkpoints written by this notebook.
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Saved Val loss: {checkpoint['val_loss']}")

In [None]:
# Rebuild the test-split loader and the tokenizer for the evaluation cells below,
# using the same arguments as the pre-training setup cell.
test_loader_kwargs = {
    'dataset_name': dataset_name,
    'features': features,
    'seq_len': seq_len,
    'label_len': 0,
    'pred_len': 96,
    'flag': 'test',
    'batch_size': batch_size,
    'pretrain': True,
}
test_dataset, test_loader = build_dataloader(**test_loader_kwargs)

# Initialize components
tokenizer_kwargs = {
    'quant_range': quant_range,
    'vocab_size': vocab_size,
    'context_length': seq_len,
    'prediction_length': 96,
}
tokenizer = build_tokenizer(**tokenizer_kwargs)

In [None]:
# Average the per-batch MSE of the current model over the test loader.
all_mse = 0.0
model.eval()
# Gradients are not needed for evaluation; without no_grad() every forward pass
# keeps its autograd graph alive and wastes memory.
with torch.no_grad():
    for i, (batch_x, batch_y, _, _) in enumerate(test_loader):
        x = batch_x.float().squeeze(-1)
        y = batch_y.float().squeeze(-1)

        token_ids, attention_mask, tokenizer_state = tokenizer.context_input_transform(x)
        # target_token_ids, target_attention_mask = tokenizer.label_input_transform(y, tokenizer_state)

        logits = model(token_ids.to(device), patch_lengths)
        loss = F.mse_loss(logits, y.to(device), reduction='mean')
        all_mse += loss.item()

        # print(f"Test Loss: {loss.item():.6f}")

# Final evaluation: mean of the per-batch means (batches are weighted equally).
# An empty loader reports NaN instead of raising ZeroDivisionError.
num_test_batches = len(test_loader)
final_mse = all_mse / num_test_batches if num_test_batches else float('nan')
print(f"\nFinal Test MSE Loss: {final_mse:.6f}")

In [None]:
# Full evaluation of the currently loaded model via utils.eval_utils.evaluation.
# Failures are reported and leave eval_results as None (best effort) so the
# rest of the notebook can still run.
eval_results = None

print(f"\n🎯 Running full evaluation at ...")
try:
    eval_results = evaluation(
        model,
        dataset_name,
        features,
        quant_range,
        vocab_size,
        input_len=seq_len,  # the model was built with max_length=seq_len; keep the two in sync
        pred_len=96,
        eval_batch_size=batch_size,
        device=device
    )
    print(f"   📈 Evaluation completed successfully!")
except Exception as e:
    # eval_results is still None here because the assignment in the try block never completed
    print(f"   ❌ Evaluation failed: {e}")