In [None]:
"""
Using EXACT Garg et al. transformer implementation
Simple training loop for your data format
"""

import torch
import torch.nn as nn
from transformers import GPT2Model, GPT2Config
import numpy as np
import time
from tqdm import tqdm
import json
import os

# ============================================================================
# EXACT GARG ET AL. TRANSFORMER (from their repo)
# ============================================================================

class TransformerModel(nn.Module):
    """GPT-2 backbone that reads interleaved (x, y) sequences and predicts y.

    Kept verbatim from the Garg et al. reference implementation on purpose.
    The attribute names ``_read_in``, ``_backbone`` and ``_read_out`` become
    the ``state_dict`` keys written to checkpoints, so renaming them would
    make existing checkpoints unloadable.
    """

    def __init__(self, n_dims: int, n_positions: int, n_embd: int = 128, n_layer: int = 12, n_head: int = 4) -> None:
        """Build the read-in projection, the GPT-2 backbone and the read-out.

        Args:
            n_dims: Feature dimension of each x token. y values are padded
                with zeros up to this width, see ``_combine``.
            n_positions: Maximum number of (x, y) pairs per sequence. The
                backbone is given ``2 * n_positions`` positions because every
                pair occupies two tokens.
            n_embd: GPT-2 hidden width.
            n_layer: Number of GPT-2 blocks.
            n_head: Number of attention heads per block.
        """
        super(TransformerModel, self).__init__()
        # Residual, embedding and attention dropout are all disabled; the
        # key/value cache is turned off.
        configuration = GPT2Config(
            n_positions=2 * n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
            use_cache=False,
        )
        self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}"

        self.n_positions = n_positions
        self.n_dims = n_dims
        # Linear map from each n_dims-wide token into the embedding space.
        self._read_in = nn.Linear(n_dims, n_embd)
        self._backbone = GPT2Model(configuration)
        # Linear map from each hidden state to a single scalar prediction.
        self._read_out = nn.Linear(n_embd, 1)

    @staticmethod
    def _combine(xs_b: torch.Tensor, ys_b: torch.Tensor) -> torch.Tensor:
        """Interleaves the x's and the y's into a single sequence.

        Each scalar y is widened to ``dim`` features: the value goes in
        feature 0 and the remaining ``dim - 1`` features are zero. The result
        is ordered ``x_1, y_1, x_2, y_2, ...``.

        Args:
            xs_b: Tensor of shape ``(bsize, points, dim)``.
            ys_b: Tensor with ``bsize * points`` elements, viewed as
                ``(bsize, points, 1)``. Presumably ``(bsize, points)``;
                TODO confirm with callers.

        Returns:
            Tensor of shape ``(bsize, 2 * points, dim)``.
        """
        bsize, points, dim = xs_b.shape
        ys_b_wide = torch.cat(
            (
                ys_b.view(bsize, points, 1),
                torch.zeros(bsize, points, dim - 1, device=ys_b.device),
            ),
            axis=2,  # NumPy-style alias for dim=2 accepted by torch.cat
        )
        # (bsize, points, 2, dim) -> (bsize, 2 * points, dim): x_i is
        # immediately followed by its own y_i.
        zs = torch.stack((xs_b, ys_b_wide), dim=2)
        zs = zs.view(bsize, 2 * points, dim)
        return zs

    def forward(self, xs: torch.Tensor, ys: torch.Tensor, inds=None) -> torch.Tensor:
        """Predict y at the selected x positions of each sequence.

        Args:
            xs: Inputs of shape ``(bsize, points, dim)``.
            ys: Targets, interleaved after their x's as context (see
                ``_combine``).
            inds: Optional iterable of point indices to return. Defaults to
                every point ``0 .. ys.shape[1] - 1``.

        Returns:
            Tensor of shape ``(bsize, len(inds))`` with the scalar prediction
            read out at each selected x token.

        Raises:
            ValueError: If any index in ``inds`` is outside
                ``[0, ys.shape[1])``.
        """
        if inds is None:
            inds = torch.arange(ys.shape[1])
        else:
            # Built on the default (CPU) device; it is only used to index the
            # output below.
            inds = torch.tensor(inds)
            if max(inds) >= ys.shape[1] or min(inds) < 0:
                raise ValueError("inds contain indices where xs and ys are not defined")
        zs = self._combine(xs, ys)
        embeds = self._read_in(zs)
        output = self._backbone(inputs_embeds=embeds).last_hidden_state
        prediction = self._read_out(output)
        # Even token positions hold the x's. GPT-2 attention is causal, so the
        # prediction at x_i sees only earlier tokens and never y_i itself.
        # Feature 0 is the scalar output of _read_out.
        return prediction[:, ::2, 0][:, inds]  # predict only on xs

# ============================================================================
# YOUR DATA LOADER
# ============================================================================

def load_data(data_path):
    """Load the train/test splits from an ``.npz`` archive as float32 tensors.

    The archive must contain the arrays ``X_train``, ``y_train``, ``X_test``
    and ``y_test``. For ``isotropic_data.npz`` the expected shapes are::

        X_train: (16000, 10, 5)    y_train: (16000, 10)
        X_test:  (4000, 10, 5)     y_test:  (4000, 10)

    Args:
        data_path: Path to the ``.npz`` file.

    Returns:
        Tuple ``(train_xs, train_ys, test_xs, test_ys)`` of ``torch.float32``
        tensors, in that order.

    Raises:
        KeyError: If one of the four required arrays is missing.
    """
    # For .npz files np.load returns an NpzFile that keeps the archive open.
    # The context manager closes it once every array has been read into
    # memory. Each data[key] access materialises an independent ndarray, so
    # the tensors stay valid after the file is closed.
    with np.load(data_path) as data:
        train_xs = torch.from_numpy(data['X_train']).float()
        train_ys = torch.from_numpy(data['y_train']).float()
        test_xs = torch.from_numpy(data['X_test']).float()
        test_ys = torch.from_numpy(data['y_test']).float()

    return train_xs, train_ys, test_xs, test_ys

# ============================================================================
# TRAINING LOOP
# ============================================================================

def train_step(model, xs, ys, optimizer, device):
    """Run one optimisation step on a single mini-batch.

    The batch is moved to ``device``, the mean squared error between
    ``model(xs, ys)`` and ``ys`` is backpropagated, the global gradient norm
    is clipped to 1.0, and ``optimizer`` takes one step.

    Args:
        model: Module called as ``model(xs, ys)``.
        xs: Batch inputs.
        ys: Batch targets (also passed to the model as context).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batch is moved to before the forward pass.

    Returns:
        The batch loss as a Python float.
    """
    inputs, targets = xs.to(device), ys.to(device)

    optimizer.zero_grad()

    residual = model(inputs, targets) - targets
    batch_loss = residual.pow(2).mean()

    batch_loss.backward()
    # Guard against occasional large gradients before the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    return batch_loss.item()

def eval_model(model, xs, ys, device, batch_size=64):
    """Return the mean squared error of ``model`` over a whole dataset.

    The dataset is processed in chunks of ``batch_size`` under
    ``torch.no_grad()``. Squared errors are summed and divided by the total
    element count, so every element carries equal weight even when the last
    batch is shorter than the others.

    Args:
        model: Module called as ``model(xs_batch, ys_batch)``.
        xs: Inputs, indexed by sample along dim 0.
        ys: Targets, indexed by sample along dim 0.
        device: Device each batch is moved to.
        batch_size: Number of samples per forward pass.

    Returns:
        The MSE as a Python float.

    Raises:
        ValueError: If ``xs`` is empty.
    """
    n_samples = len(xs)
    if n_samples == 0:
        raise ValueError("eval_model needs at least one sample")

    # Restore whatever mode the caller had, rather than forcing train mode.
    was_training = model.training
    model.eval()

    total_sq_err = 0.0
    n_elements = 0
    try:
        with torch.no_grad():
            for start in range(0, n_samples, batch_size):
                batch_xs = xs[start:start + batch_size].to(device)
                batch_ys = ys[start:start + batch_size].to(device)

                predictions = model(batch_xs, batch_ys)
                sq_err = (predictions - batch_ys) ** 2

                # Accumulate sums, not per-batch means, so a short final batch
                # is not over-weighted in the average.
                total_sq_err += sq_err.sum().item()
                n_elements += sq_err.numel()
    finally:
        model.train(was_training)

    return total_sq_err / n_elements

def train_model(model, train_xs, train_ys, test_xs, test_ys, config, device):
    """Train ``model`` with Adam on random mini-batches and track test MSE.

    Every ``log_every`` steps (and at the final step) the model is evaluated
    on the full test set and one entry is appended to each history list.
    Whenever the test loss improves, a checkpoint is written to
    ``config['checkpoint_path']``.

    Args:
        model: Module called as ``model(xs, ys)``.
        train_xs: Training inputs, indexed by sample along dim 0.
        train_ys: Training targets, indexed by sample along dim 0.
        test_xs: Test inputs, evaluated in full at each log point.
        test_ys: Test targets, evaluated in full at each log point.
        config: Dict with keys ``'lr'``, ``'steps'``, ``'batch_size'`` and
            ``'checkpoint_path'``; optional ``'log_every'`` (default 1000).
        device: Device the model and batches are moved to.

    Returns:
        Dict of equal-length lists:
            ``'train_loss'``: mean mini-batch loss since the previous log point.
            ``'test_loss'``: full test-set MSE.
            ``'step'``: step number of each log point.
            ``'wall_time'``: seconds since training started. This includes
                time spent evaluating and checkpointing.

    Raises:
        ValueError: If ``log_every`` is smaller than 1.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])

    total_steps = config['steps']
    batch_size = config['batch_size']
    log_every = config.get('log_every', 1000)
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")
    checkpoint_path = config['checkpoint_path']

    # torch.save does not create parent directories. Create it now so the
    # first checkpoint save cannot fail after a long stretch of training.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    n_train = len(train_xs)
    best_test_loss = float('inf')

    history = {
        'train_loss': [],
        'test_loss': [],
        'step': [],
        'wall_time': [],
    }

    print(f"\n{'='*70}")
    print(f"Training for {total_steps} steps")
    print(f"Batch size: {batch_size}")
    print(f"Train samples: {n_train}, Test samples: {len(test_xs)}")
    print(f"{'='*70}\n")

    # Running sum of mini-batch losses since the previous log point. The
    # logged train loss is the window average, not one noisy 32-sample batch.
    window_loss_sum = 0.0
    window_steps = 0
    start_time = time.time()

    # The context manager closes the progress bar even if training raises.
    with tqdm(total=total_steps, desc="Training") as pbar:
        for step in range(1, total_steps + 1):
            # Draw batch indices uniformly at random, with replacement.
            batch_idx = torch.randint(0, n_train, (batch_size,))
            train_loss = train_step(
                model, train_xs[batch_idx], train_ys[batch_idx], optimizer, device
            )
            window_loss_sum += train_loss
            window_steps += 1

            pbar.update(1)
            pbar.set_postfix({'loss': f'{train_loss:.4f}'})

            if step % log_every != 0 and step != total_steps:
                continue

            mean_train_loss = window_loss_sum / window_steps
            window_loss_sum, window_steps = 0.0, 0

            test_loss = eval_model(model, test_xs, test_ys, device, batch_size)
            elapsed_time = time.time() - start_time

            history['train_loss'].append(mean_train_loss)
            history['test_loss'].append(test_loss)
            history['step'].append(step)
            history['wall_time'].append(elapsed_time)

            # pbar.write keeps log lines from being interleaved with the bar.
            pbar.write(
                f"Step {step:5d} | Train Loss: {mean_train_loss:.6f} | "
                f"Test Loss: {test_loss:.6f} | Time: {elapsed_time/3600:.2f}h"
            )

            if test_loss < best_test_loss:
                best_test_loss = test_loss
                torch.save({
                    'step': step,
                    'model_state_dict': model.state_dict(),
                    'test_loss': test_loss,
                    'wall_time': elapsed_time,
                }, checkpoint_path)
                pbar.write(f"✓ Saved new best model (test_loss: {test_loss:.6f})")

    total_time = time.time() - start_time
    print(f"\n{'='*70}")
    print("Training complete!")
    print(f"Best test loss: {best_test_loss:.6f}")
    print(f"Total wall-clock time: {total_time/3600:.2f} hours")
    print(f"{'='*70}\n")

    return history

# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    # Seed torch (weight init and batch sampling) and numpy.
    torch.manual_seed(42)
    np.random.seed(42)

    # Config
    config = {
        # Model params (6-layer as you specified)
        'n_dims': 5,         # your feature dimension
        'n_positions': 10,   # your context points
        'n_embd': 256,
        'n_layer': 6,        # YOUR DEPTH
        'n_head': 4,

        # Training params
        'batch_size': 32,
        'lr': 1e-4,          # Garg et al. use 1e-4
        'steps': 50000,      # 50k steps as in your proposal

        # Paths
        'data_path': 'data/isotropic_data.npz',
        'checkpoint_path': 'checkpoints/transformer_6layer_best.pt',
        'history_path': 'results/training_history_6layer.json',
    }

    # Create output directories from the configured paths. Editing a path
    # above then cannot leave its directory missing at save time.
    for output_path in (config['checkpoint_path'], config['history_path']):
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}\n")

    # Load your data
    print("Loading data...")
    train_xs, train_ys, test_xs, test_ys = load_data(config['data_path'])
    print(f"Train: {train_xs.shape}, Test: {test_xs.shape}")

    # Create model (EXACT Garg et al. architecture)
    print("\nInitializing transformer model...")
    model = TransformerModel(
        n_dims=config['n_dims'],
        n_positions=config['n_positions'],
        n_embd=config['n_embd'],
        n_layer=config['n_layer'],
        n_head=config['n_head']
    )

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model: {model.name}")
    print(f"Total parameters: {n_params:,}")

    # Train
    print("\nStarting training...")
    history = train_model(
        model,
        train_xs, train_ys,
        test_xs, test_ys,
        config,
        device
    )

    # Save history (plain lists of floats/ints, so directly JSON-serialisable)
    history_path = config['history_path']
    with open(history_path, 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2)

    print(f"\nHistory saved to {history_path}")

Using device: cuda

Loading data...
Train: torch.Size([16000, 10, 5]), Test: torch.Size([4000, 10, 5])

Initializing transformer model...
Model: gpt2_embd=256_layer=6_head=4
Total parameters: 17,611,777

Starting training...

Training for 50000 steps
Batch size: 32
Train samples: 16000, Test samples: 4000



Training:   2%|▏         | 981/50000 [00:25<19:07, 42.71it/s, loss=5.1008]

In [None]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np

def _plot_loss_curves(x, train_loss, test_loss, xlabel, title, out_path):
    """Save a train/test MSE line plot against ``x`` to ``out_path``."""
    plt.figure()
    plt.plot(x, train_loss, label="Train loss")
    plt.plot(x, test_loss, label="Test loss")
    plt.xlabel(xlabel)
    plt.ylabel("MSE loss")
    plt.title(title)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()
    print(f"Saved {out_path}")


def plot_history(history_path, save_dir="results"):
    """Plot a training-history JSON file and save the figures as PNGs.

    The JSON must hold equal-length lists under ``"train_loss"``,
    ``"test_loss"``, ``"step"`` and ``"wall_time"`` (seconds). Figures written
    to ``save_dir``:

    * ``loss_vs_steps.png``: train/test loss against training step.
    * ``loss_vs_time.png``: train/test loss against wall-clock hours.
    * ``time_per_step_vs_steps.png``: seconds per step between consecutive
      log points. Requires at least two log points; otherwise it is skipped.
      The logged wall time also counts evaluation and checkpointing, so this
      over-estimates the pure optimisation time per step.

    Args:
        history_path: Path to the history JSON file.
        save_dir: Output directory, created if missing.
    """
    with open(history_path, "r", encoding="utf-8") as f:
        history = json.load(f)

    train_loss = np.array(history["train_loss"])
    test_loss = np.array(history["test_loss"])
    steps = np.array(history["step"])
    wall_time = np.array(history["wall_time"])  # seconds

    os.makedirs(save_dir, exist_ok=True)

    # 1) Loss vs training steps
    _plot_loss_curves(
        steps, train_loss, test_loss,
        "Training step",
        "Train/Test Loss vs Steps",
        os.path.join(save_dir, "loss_vs_steps.png"),
    )

    # 2) Loss vs wall-clock time (in hours)
    _plot_loss_curves(
        wall_time / 3600.0, train_loss, test_loss,
        "Wall-clock time (hours)",
        "Train/Test Loss vs Wall-Clock Time",
        os.path.join(save_dir, "loss_vs_time.png"),
    )

    # 3) Time per step: Δtime / Δsteps between consecutive log points.
    #    With fewer than two points there is no interval to measure, and the
    #    mean of an empty array would be nan.
    if len(steps) < 2:
        print("Need at least two logged points for seconds/step; "
              "skipping time-per-step plot.")
        return

    dt = np.diff(wall_time)         # seconds between logs
    dsteps = np.diff(steps)         # steps between logs (last interval may be shorter)
    time_per_step = dt / dsteps     # seconds per step
    step_ends = steps[1:]           # align each interval with its later step

    plt.figure()
    plt.plot(step_ends, time_per_step)
    plt.xlabel("Training step")
    plt.ylabel("Seconds per step")
    plt.title("Time per Step vs Training Step")
    plt.grid(True, alpha=0.3)
    tps_path = os.path.join(save_dir, "time_per_step_vs_steps.png")
    plt.savefig(tps_path, bbox_inches="tight")
    plt.close()
    print(f"Saved {tps_path}")

    # Summary stats
    print(f"Mean seconds/step: {time_per_step.mean():.6f}")
    print(f"Std seconds/step: {time_per_step.std():.6f}")


if __name__ == "__main__":
    # JSON written by the training cell; change it if that run used another path.
    history_path = "results/training_history_6layer.json"
    plot_history(history_path=history_path, save_dir="results")