# In [None]:


import os
import argparse
import math
import random
import json
from datetime import datetime, timedelta
import subprocess
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# try to import torch and sklearn; if not available, instruct user to install via generated requirements
try:
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader
except Exception:
    print("PyTorch not available. Please install packages listed in the generated requirements.txt (run run_all.sh).")
    # allow the script to still write helper files even if torch missing
    torch = None
    nn = None
    Dataset = object
    DataLoader = None

try:
    from sklearn.preprocessing import StandardScaler
    from sklearn.metrics import mean_squared_error, mean_absolute_error
except Exception:
    print("scikit-learn not available. Please install packages listed in the generated requirements.txt (run run_all.sh).")
    StandardScaler = None
    mean_squared_error = None
    mean_absolute_error = None

# Reference images from the original project brief; only echoed in the final log line.
BRIEF_IMAGE_1 = "/mnt/data/Arun1.png"
BRIEF_IMAGE_2 = "/mnt/data/Arun2.png"

# Output layout, relative to the current working directory:
#   data/                          input CSVs (synthetic series)
#   results/, results/models/,
#   results/figs/attention/        metrics, checkpoints and figures
#   experiments/                   HP-search output
#   tests/                         generated tests
DATA_DIR = "data"
EXPERIMENTS_DIR = "experiments"
TESTS_DIR = "tests"
OUT_DIR = "results"
MODEL_DIR = os.path.join(OUT_DIR, "models")
FIG_DIR = os.path.join(OUT_DIR, "figs")
ATTN_DIR = os.path.join(FIG_DIR, "attention")

# Created eagerly at import time so every later writer can assume the folders exist.
for d in (DATA_DIR, OUT_DIR, MODEL_DIR, FIG_DIR, ATTN_DIR,
          EXPERIMENTS_DIR, TESTS_DIR):
    os.makedirs(d, exist_ok=True)

# ---------------------------
# Utilities
# ---------------------------
def seed_everything(seed=42):
    """Seed Python's and NumPy's global RNGs, plus PyTorch's when it is importable.

    With PyTorch present this also seeds every CUDA device and puts cuDNN into
    deterministic, non-benchmarking mode for reproducible runs.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    if torch is None:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def mape(y_true, y_pred):
    """Mean absolute percentage error, expressed in percent.

    Target magnitudes below 1e-8 are clamped to 1e-8 before dividing, so zero
    targets do not raise or produce inf (they still yield very large terms).
    """
    actual = np.array(y_true)
    predicted = np.array(y_pred)
    safe_denominator = np.maximum(np.abs(actual), 1e-8)
    relative_errors = np.abs((actual - predicted) / safe_denominator)
    return relative_errors.mean() * 100

# ---------------------------
# Synthetic dataset
# ---------------------------
def generate_synthetic(length=15000, freq='H', seed=0, out_csv=None):
    """Build a synthetic univariate series with trend, seasonality, changepoints and volatility bursts.

    Args:
        length: number of timesteps.
        freq: timestamp spacing -- 'H' one hour, 'D' one day, anything else one minute.
        seed: passed to ``seed_everything`` before any noise is drawn.
        out_csv: optional path; when truthy the frame is also written there (no index).

    Returns:
        DataFrame with columns ``ds`` (timestamps from 2015-01-01) and ``y``.

    Note: the seasonal periods are counted in timesteps (24 and 168 steps), so they
    represent daily/weekly cycles only when ``freq='H'``.
    """
    seed_everything(seed)

    origin = datetime(2015, 1, 1)
    step = (timedelta(hours=1) if freq == 'H'
            else timedelta(days=1) if freq == 'D'
            else timedelta(minutes=1))
    dates = [origin + step * k for k in range(length)]

    steps = np.arange(length)
    trend = 0.0002 * (steps ** 1.05)
    daily = 3.0 * np.sin(2 * np.pi * (steps % 24) / 24)
    weekly = 1.5 * np.sin(2 * np.pi * (steps % 168) / 168)

    # Permanent level shifts of +0.5, +1.0, +1.5 after 20%, 50% and 75% of the series.
    level_shift = np.zeros(length, dtype=float)
    changepoints = (int(length * 0.2), int(length * 0.5), int(length * 0.75))
    for k, cp in enumerate(changepoints, start=1):
        level_shift += (steps > cp) * (0.5 * k)

    # Base noise plus two Gaussian-enveloped bursts of high-variance noise.
    noise = np.random.normal(0, 0.6, size=length)
    burst_width = int(length * 0.01) if length > 10000 else 200
    for center in (int(length * 0.35), int(length * 0.6)):
        envelope = np.exp(-((steps - center) ** 2) / (2 * (burst_width ** 2)))
        noise += envelope * np.random.normal(0, 8.0, size=length)

    y = 10 + trend + daily + weekly + level_shift + noise
    frame = pd.DataFrame({'ds': dates, 'y': y})
    if out_csv:
        frame.to_csv(out_csv, index=False)
    return frame

# ---------------------------
# Dataset and DataLoader
# ---------------------------
class TimeSeriesWindowDataset(Dataset):
    """Sliding-window dataset yielding ``(input_window, target_window)`` pairs.

    Window ``idx`` uses ``series[idx : idx + input_len]`` as input and the following
    ``horizon`` rows as target. Both are float32 arrays of shape
    ``(input_len, n_features)`` and ``(horizon, n_features)``.
    """

    def __init__(self, series: np.ndarray, input_len: int, horizon: int):
        """
        Args:
            series: array-like of shape (T,) or (T, n_features); 1-D input is treated
                as a single feature.
            input_len: number of timesteps in each input window.
            horizon: number of timesteps in each target window.
        """
        series = np.asarray(series)
        if series.ndim == 1:
            series = series.reshape(-1, 1)
        self.series = series.astype(np.float32)
        self.input_len = input_len
        self.horizon = horizon
        self.T = series.shape[0]
        self.n_features = series.shape[1]

    def __len__(self):
        # number of complete (input, target) windows; 0 when the series is too short
        return max(0, self.T - self.input_len - self.horizon + 1)

    def __getitem__(self, idx):
        """Return window ``idx`` (negative indices count from the end).

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``. Besides
                catching bad indices, this is what terminates plain
                ``for x, y in dataset`` loops; without it, out-of-range slices are
                simply empty and such loops never end.
        """
        n = len(self)
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            raise IndexError(f"window index {idx} out of range for {n} windows")
        split = pos + self.input_len
        x = self.series[pos:split]
        y = self.series[split:split + self.horizon]
        return x, y

def collate_windows(batch):
    """Stack ``(x, y)`` window pairs into channel-first tensors for a ``DataLoader``.

    Each ``x`` is (input_len, feat) and each ``y`` is (horizon, feat); the returned
    tensors are ``(B, feat, input_len)`` and ``(B, feat, horizon)``, the layout
    ``Conv1d`` expects.
    """
    def _stacked(position):
        return torch.stack([torch.tensor(pair[position]) for pair in batch])

    inputs = _stacked(0)
    targets = _stacked(1)
    return inputs.permute(0, 2, 1), targets.permute(0, 2, 1)

# ---------------------------
# Robust attention hook registration
# ---------------------------
def register_transformer_attention_hooks(transformer_encoder, attn_container):
    """
    Attach forward hooks that copy attention weights into ``attn_container``.

    For every layer in ``transformer_encoder.layers`` exposing a ``self_attn``
    submodule, one forward hook is registered. Whenever that submodule runs and
    yields weights, a detached CPU copy is appended to ``attn_container``.

    Weights are taken from the second element of a tuple output (the
    ``nn.MultiheadAttention`` convention). For non-tuple outputs, an
    ``attn_output_weights`` attribute is used, falling back to ``attn_weights``.

    NOTE(review): ``nn.TransformerEncoderLayer`` appears to call ``self_attn`` with
    ``need_weights=False``, so the tuple's second element is None. It may also skip
    the submodule entirely on its fused inference fast path. In both cases nothing
    is captured; confirm against the installed PyTorch version.

    Args:
        transformer_encoder: module with a ``layers`` iterable, e.g. ``nn.TransformerEncoder``.
            Anything without ``layers`` gets no hooks.
        attn_container: list that captured tensors are appended to. The hooks keep a
            reference to this exact object, so callers must empty it in place
            (``.clear()``) rather than rebinding the name.

    Returns:
        list of hook handles; call ``.remove()`` on each to detach the hooks.
    """
    handles = []

    def _capture(module, inp, out):
        if isinstance(out, tuple) and len(out) >= 2:
            attn_weights = out[1]
        else:
            # Explicit ``is None`` checks: ``a or b`` would call bool() on a tensor,
            # which raises for tensors with more than one element.
            attn_weights = getattr(module, "attn_output_weights", None)
            if attn_weights is None:
                attn_weights = getattr(module, "attn_weights", None)
        if attn_weights is None:
            return
        try:
            attn_container.append(attn_weights.detach().cpu())
        except Exception:
            # Deliberate best effort: visualisation must never break a forward pass.
            pass

    for i, layer in enumerate(getattr(transformer_encoder, "layers", [])):
        mha = getattr(layer, "self_attn", None)
        if mha is None:
            continue
        try:
            handles.append(mha.register_forward_hook(_capture))
        except Exception as e:
            print(f"Warning: failed to register hook on layer {i}: {e}")
    return handles

# ---------------------------
# Models
# ---------------------------
if torch is not None:
    class SimplePositionalEncoding(nn.Module):
        """Add fixed sinusoidal position information to channel-first inputs.

        The table is registered as a buffer of shape (1, d_model, max_len), so it is
        saved in ``state_dict`` and follows ``.to(device)``.
        """

        def __init__(self, d_model, max_len=2000):
            super().__init__()
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            # With an odd d_model there is one fewer cosine column than sine column;
            # trimming div_term keeps the shapes aligned (a no-op for even d_model).
            pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
            pe = pe.unsqueeze(0).transpose(1, 2)  # (1, d_model, max_len)
            self.register_buffer('pe', pe)

        def forward(self, x):
            """Return ``x + pe`` for ``x`` of shape (B, d_model, L) with ``L <= max_len``."""
            L = x.size(-1)
            max_len = self.pe.size(-1)
            if L > max_len:
                # Without this check the slice below silently returns max_len columns
                # and the addition fails with an opaque broadcasting error.
                raise ValueError(f"sequence length {L} exceeds positional-encoding max_len {max_len}")
            return x + self.pe[:, :, :L]

    class TransformerForecaster(nn.Module):
        """Encoder-only Transformer mapping an input window to a fixed-horizon forecast.

        Pipeline:
            1x1 Conv1d projection to ``d_model`` channels
            -> sinusoidal positional encoding
            -> ``nn.TransformerEncoder`` (batch_first, GELU)
            -> average pooling over time
            -> 2-layer MLP emitting ``horizon * input_dim`` values.

        Input shape (B, input_dim, L); output shape (B, input_dim, horizon).
        """

        def __init__(self, input_dim=1, d_model=64, n_heads=4, n_layers=3, d_ff=128,
                     dropout=0.1, horizon=50, input_len=256, pos_max_len=2048):
            super().__init__()
            self.input_dim = input_dim
            self.d_model = d_model
            self.horizon = horizon
            self.input_len = input_len

            self.input_proj = nn.Conv1d(in_channels=input_dim, out_channels=d_model, kernel_size=1)
            self.pos_enc = SimplePositionalEncoding(d_model, max_len=max(pos_max_len, input_len))

            encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads,
                                                       dim_feedforward=d_ff, dropout=dropout,
                                                       activation='gelu', batch_first=True)
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
            self.head = nn.Sequential(
                nn.AdaptiveAvgPool1d(1),
                nn.Flatten(),
                nn.Linear(d_model, d_model),
                nn.GELU(),
                nn.Linear(d_model, horizon * input_dim)
            )
            # Container for attention weights captured by forward hooks. The hooks keep a
            # reference to this list object, so it must only ever be mutated in place.
            # NOTE(review): nn.TransformerEncoderLayer seems to request need_weights=False,
            # so this may stay empty -- confirm for the installed torch version.
            self._attn_weights = []
            try:
                register_transformer_attention_hooks(self.encoder, self._attn_weights)
            except Exception as e:
                print("Warning: registering attention hooks failed:", e)

        def forward(self, x, return_attn=False):
            """Forecast from ``x`` of shape (B, C, L).

            Returns the (B, input_dim, horizon) forecast. When ``return_attn`` is true,
            returns ``(forecast, attn_list)``, where ``attn_list`` is a snapshot of what
            the hooks captured during this call.
            """
            B, _, _ = x.shape
            x = self.input_proj(x)  # (B, d_model, L)
            x = self.pos_enc(x)
            x_t = x.permute(0, 2, 1)  # (B, L, d_model)
            # Empty the container IN PLACE. Rebinding it to a new list would orphan the
            # list the hooks append to: nothing would ever be returned, and the old list
            # would grow without bound.
            self._attn_weights.clear()
            enc = self.encoder(x_t)
            enc_p = enc.permute(0, 2, 1)  # back to (B, d_model, L) for AdaptiveAvgPool1d
            out = self.head(enc_p)
            out = out.view(B, self.input_dim, self.horizon)
            if return_attn:
                # shallow copy so the next forward's clear() does not empty the caller's list
                return out, list(self._attn_weights)
            return out

    class LSTMForecaster(nn.Module):
        """Stacked-LSTM baseline: the final hidden state feeds an MLP that emits the horizon.

        Input shape (B, input_dim, L); output shape (B, input_dim, horizon).
        """

        def __init__(self, input_dim=1, hidden=128, n_layers=2, horizon=50):
            super().__init__()
            self.input_dim = input_dim
            self.horizon = horizon
            # nn.LSTM applies dropout only between stacked layers. With a single layer a
            # non-zero value has no effect and merely emits a UserWarning.
            self.rnn = nn.LSTM(input_dim, hidden, n_layers, batch_first=True,
                               dropout=0.1 if n_layers > 1 else 0.0)
            self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, horizon * input_dim))

        def forward(self, x):
            x = x.permute(0, 2, 1)  # (B, C, L) -> (B, L, C) for the batch_first LSTM
            out, _ = self.rnn(x)
            last = out[:, -1, :]  # output at the final timestep
            return self.head(last).view(x.size(0), self.input_dim, self.horizon)

else:
    TransformerForecaster = None
    LSTMForecaster = None

# ---------------------------
# Prophet baseline wrapper (optional)
# ---------------------------
def try_import_prophet():
    """Return the ``Prophet`` class if the optional ``prophet`` package imports, else None.

    Any failure during the import yields None, not only a missing package, so callers
    can treat Prophet as a purely optional baseline.
    """
    try:
        from prophet import Prophet
    except Exception:
        Prophet = None
    return Prophet

class ProphetWrapper:
    """
    Lightweight Prophet wrapper. Uses pandas Series indexed by datetimes.
    Note: Prophet fits can be slow when run per sliding window.
    """
    def __init__(self, **kwargs):
        """Create the underlying model; ``kwargs`` are passed straight to ``Prophet(...)``.

        Raises:
            ImportError: if the ``prophet`` package cannot be imported.
        """
        Prophet = try_import_prophet()
        if Prophet is None:
            raise ImportError("prophet not installed. Install via requirements.txt (run_all.sh).")
        self.model = Prophet(**kwargs)

    def fit(self, series: pd.Series):
        """Fit on ``series`` (datetime index -> ``ds``, values -> ``y``) and return ``self``."""
        df = pd.DataFrame({'ds': series.index, 'y': series.values})
        self.model.fit(df)
        return self

    def forecast_periods(self, periods, freq='H'):
        """Return ``yhat`` for the ``periods`` steps after the training data, as a NumPy array.

        Returns an empty array when ``periods <= 0``.
        """
        if periods <= 0:
            return np.array([], dtype=float)
        # include_history=False builds and predicts only the future rows. The old code
        # predicted over the entire training history and then kept just .tail(periods).
        future = self.model.make_future_dataframe(periods=periods, freq=freq, include_history=False)
        fc = self.model.predict(future)
        return fc['yhat'].values

# ---------------------------
# Training + evaluation helpers
# ---------------------------
def train_one_epoch(model, loader, opt, loss_fn, device):
    """Run one optimisation pass over ``loader`` and return the sample-weighted mean loss.

    Args:
        model: forecaster called as ``model(xb)``; its output is compared to ``yb``.
        loader: iterable of ``(xb, yb)`` tensor batches.
        opt: optimiser over ``model``'s parameters.
        loss_fn: callable ``loss_fn(preds, targets)`` returning a scalar tensor.
        device: device the batches are moved to.

    Returns:
        Mean per-sample loss over the epoch (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    count = 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)
        opt.zero_grad()
        # One plain call for every model type. Training never uses attention maps, so
        # requesting them (and type-checking against a possibly-None global) is waste.
        preds = model(xb)
        loss = loss_fn(preds, yb)
        loss.backward()
        opt.step()
        batch_n = xb.size(0)
        total_loss += loss.item() * batch_n
        count += batch_n
    return total_loss / max(1, count)

def validate(model, loader, loss_fn, device):
    """Evaluate ``model`` on ``loader`` without gradients; return the sample-weighted mean loss.

    Puts the model in eval mode, so dropout is disabled.

    Args:
        model: forecaster called as ``model(xb)``.
        loader: iterable of ``(xb, yb)`` tensor batches.
        loss_fn: callable ``loss_fn(preds, targets)`` returning a scalar tensor.
        device: device the batches are moved to.

    Returns:
        Mean per-sample loss (0.0 for an empty loader).
    """
    model.eval()
    total_loss = 0.0
    count = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)
            # attention maps are not needed for a validation score
            preds = model(xb)
            loss = loss_fn(preds, yb)
            batch_n = xb.size(0)
            total_loss += loss.item() * batch_n
            count += batch_n
    return total_loss / max(1, count)

def save_checkpoint(path, model, model_args):
    """Serialise ``model``'s weights and constructor kwargs to ``path`` with ``torch.save``.

    The file holds ``{'state_dict': ..., 'model_args': ...}``, so a loader can rebuild
    the model as ``model_cls(**model_args)`` before loading the weights. Missing
    parent directories are created first, because ``torch.save`` does not create them.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save({'state_dict': model.state_dict(), 'model_args': model_args}, path)

def _predict_steps(model, xb, steps, attn_sink):
    """Forecast ``steps`` values for every window in ``xb`` (shape (B, C, L)).

    The model predicts its native horizon per pass. When ``steps`` is longer, the
    predictions are appended to the input window (keeping its length L) and fed back
    in until enough values exist. The result is trimmed to exactly ``steps``.
    Transformer attention maps are kept from the first pass only, which is the one
    that sees purely observed data.
    """
    window_len = xb.size(-1)
    window = xb
    chunks = []
    produced = 0
    while produced < steps:
        if isinstance(model, TransformerForecaster):
            out, attn = model(window, return_attn=True)
            if attn and produced == 0:
                attn_sink.extend(attn)
        else:
            out = model(window)
        if out.size(-1) == 0:
            raise ValueError("model produced an empty forecast; cannot roll it forward")
        chunks.append(out)
        produced += out.size(-1)
        if produced < steps:
            window = torch.cat([window, out], dim=-1)[..., -window_len:]
    return torch.cat(chunks, dim=-1)[..., :steps]


def evaluate_model_checkpoint(checkpoint_path, model_cls, scaler, test_series, input_len, horizon, device, batch_size=64):
    """Rebuild a saved model and forecast every sliding window of ``test_series``.

    Args:
        checkpoint_path: file written by ``save_checkpoint``.
        model_cls: class instantiated with the checkpoint's ``model_args``.
        scaler: the fitted training scaler. It is applied to ``test_series`` and
            inverted on the outputs.
        test_series: 1-D array of raw (unscaled) values.
        input_len: input window length; should match the one used in training.
        horizon: number of steps to evaluate. It may differ from the model's trained
            horizon: shorter requests keep the first ``horizon`` predicted steps, and
            longer ones are produced by feeding predictions back in. (Previously the
            model always emitted its trained horizon while the targets used
            ``horizon``, so any other value produced misaligned arrays.)
        device: torch device for inference.
        batch_size: windows per inference batch.

    Returns:
        ``(preds_inv, trues_inv, attn_collected)``:
        - two arrays of shape (N_windows, C, horizon) in original units;
        - the attention tensors captured from Transformer models (empty otherwise).

    Raises:
        ValueError: if ``horizon < 1`` or ``test_series`` is too short for one window.
    """
    if horizon < 1:
        raise ValueError(f"horizon must be >= 1, got {horizon}")
    ckpt = torch.load(checkpoint_path, map_location=device)
    model = model_cls(**ckpt['model_args']).to(device)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    ds_scaled = scaler.transform(np.asarray(test_series).reshape(-1, 1)).flatten()
    dataset = TimeSeriesWindowDataset(ds_scaled, input_len=input_len, horizon=horizon)
    if len(dataset) == 0:
        raise ValueError(
            f"test series of length {len(ds_scaled)} is too short for "
            f"input_len={input_len} + horizon={horizon}"
        )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_windows)

    preds, trues, attn_collected = [], [], []
    with torch.no_grad():
        for xb, yb in loader:
            py = _predict_steps(model, xb.to(device), horizon, attn_collected)
            preds.append(py.cpu().numpy())
            trues.append(yb.cpu().numpy())
    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)
    _, C, H = preds.shape

    def _to_original_units(arr):
        # The scaler works on (n_samples, n_features): (N, C, H) -> (N*H, C) and back.
        flat = arr.transpose(0, 2, 1).reshape(-1, C)
        return scaler.inverse_transform(flat).reshape(-1, H, C).transpose(0, 2, 1)

    return _to_original_units(preds), _to_original_units(trues), attn_collected

# ---------------------------
# High-level experiment runner
# ---------------------------
def _plot_attention_maps(attn, limit=6):
    """Save up to ``limit`` attention tensors as heatmaps ``attn_map_{i}.png`` under ATTN_DIR.

    4-D arrays are averaged over their first two axes and 3-D arrays over the first;
    anything else is plotted as-is. Failures are reported and skipped, and the
    figure is always closed.
    """
    os.makedirs(ATTN_DIR, exist_ok=True)
    for i, a in enumerate(attn[:limit]):
        try:
            arr = a.numpy()
            if arr.ndim == 4:
                # presumably (B, num_heads, tgt, src): average over batch, then heads
                arr2d = arr.mean(axis=0).mean(axis=0)
            elif arr.ndim == 3:
                # (B, tgt, src) or (num_heads, tgt, src) depending on the torch version
                arr2d = arr.mean(axis=0)
            else:
                arr2d = arr
            plt.figure(figsize=(6, 4))
            plt.imshow(arr2d, aspect='auto')
            plt.colorbar()
            plt.title(f'Attention map sample {i}')
            plt.savefig(os.path.join(ATTN_DIR, f'attn_map_{i}.png'))
        except Exception as e:
            print("Failed to plot attention", e)
        finally:
            # close even on failure so figures do not accumulate
            plt.close()


def run_experiment(args):
    """Train one forecaster end to end and write its artefacts.

    Steps:
    1. Load (``--data-path``) or generate the series.
    2. Split it chronologically 70/15/15.
    3. Standardise using statistics from the training split only.
    4. Train with early stopping on validation loss, checkpointing the best epoch.
    5. Evaluate that checkpoint on the test split at horizons 10, 50 and 100.

    Training history, metrics, scaler.json and figures are written under OUT_DIR.

    Raises:
        RuntimeError: if PyTorch or scikit-learn is not installed.
    """
    if torch is None:
        raise RuntimeError("PyTorch is required to train models. Install dependencies via generated requirements.txt and run run_all.sh.")
    if StandardScaler is None:
        raise RuntimeError("scikit-learn is required for scaling and metrics. Install dependencies via generated requirements.txt and run run_all.sh.")
    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    seed_everything(args.seed)

    # 1) Data
    if args.generate:
        df = generate_synthetic(length=args.length, freq=args.freq, seed=args.seed,
                                out_csv=os.path.join(DATA_DIR, f'synthetic_{args.length}.csv'))
    else:
        df = pd.read_csv(args.data_path)
    series = df['y'].values.astype(float)

    # Quick plot of the first 2000 points. Timestamps read back from CSV are strings,
    # so parse them; otherwise matplotlib draws thousands of categorical ticks.
    x_axis = df['ds'].iloc[:2000]
    try:
        x_axis = pd.to_datetime(x_axis)
    except (ValueError, TypeError):
        pass  # not parseable as dates: plot the labels as they are
    plt.figure(figsize=(12, 3))
    plt.plot(x_axis, series[:2000])
    plt.title('Sample series (first 2000 timesteps)')
    plt.tight_layout()
    plt.savefig(os.path.join(FIG_DIR, 'sample_series.png'))
    plt.close()

    # 2) Chronological 70/15/15 split; the scaler is fitted on training data only
    T = len(series)
    train_end = int(T * 0.7)
    val_end = int(T * 0.85)
    train_series = series[:train_end]
    val_series = series[train_end:val_end]
    test_series = series[val_end:]  # scaled later, inside evaluate_model_checkpoint

    scaler = StandardScaler()
    train_scaled = scaler.fit_transform(train_series.reshape(-1, 1)).flatten()
    val_scaled = scaler.transform(val_series.reshape(-1, 1)).flatten()

    # persist scaler
    with open(os.path.join(OUT_DIR, 'scaler.json'), 'w') as f:
        json.dump({'mean': float(scaler.mean_[0]), 'scale': float(scaler.scale_[0])}, f)

    # 3) Loaders
    input_len = args.input_len
    horizon = args.horizon
    train_loader = DataLoader(TimeSeriesWindowDataset(train_scaled, input_len, horizon),
                              batch_size=args.batch_size, shuffle=True, collate_fn=collate_windows)
    val_loader = DataLoader(TimeSeriesWindowDataset(val_scaled, input_len, horizon),
                            batch_size=args.batch_size, shuffle=False, collate_fn=collate_windows)

    # 4) Model -- built from the same kwargs that are checkpointed, so they cannot drift
    if args.model == 'transformer':
        model_cls = TransformerForecaster
        model_args = {'input_dim': 1, 'd_model': args.d_model, 'n_heads': args.n_heads, 'n_layers': args.n_layers,
                      'd_ff': args.d_ff, 'dropout': args.dropout, 'horizon': horizon, 'input_len': input_len}
    else:
        model_cls = LSTMForecaster
        model_args = {'input_dim': 1, 'hidden': args.hidden, 'n_layers': args.n_layers, 'horizon': horizon}
    model = model_cls(**model_args).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    loss_fn = nn.MSELoss()

    # 5) Training with best-checkpointing and early stopping
    best_val = float('inf')
    best_path = None
    history = []
    epochs_without_improvement = 0

    for epoch in range(1, args.epochs + 1):
        tr_loss = train_one_epoch(model, train_loader, opt, loss_fn, device)
        val_loss = validate(model, val_loader, loss_fn, device)
        history.append({'epoch': epoch, 'train_loss': tr_loss, 'val_loss': val_loss})
        print(f"Epoch {epoch} — train_loss {tr_loss:.6f}, val_loss {val_loss:.6f}")

        if val_loss < best_val:
            best_val = val_loss
            epochs_without_improvement = 0
            best_path = os.path.join(MODEL_DIR, f'best_{args.model}.pth')
            save_checkpoint(best_path, model, model_args)
            print("Saved best model to", best_path)
        else:
            epochs_without_improvement += 1

        # Stop once `patience` consecutive epochs fail to beat the best validation loss
        # (patience <= 0 disables early stopping).
        if args.patience > 0 and epochs_without_improvement >= args.patience:
            print("Early stopping triggered")
            break

    pd.DataFrame(history).to_csv(os.path.join(OUT_DIR, f'train_history_{args.model}.csv'), index=False)

    # 6) Evaluate the best checkpoint across multiple horizons
    metrics_records = []
    if best_path is None:
        print("No checkpoint found, skipping evaluation")
    else:
        for h in [10, 50, 100]:
            print("Evaluating horizon", h)
            preds_inv, trues_inv, attn = evaluate_model_checkpoint(best_path, model_cls, scaler, test_series,
                                                                   input_len, h, device, batch_size=args.batch_size)
            if preds_inv.shape != trues_inv.shape:
                # defensive: never score predictions against misaligned targets
                print(f"Skipping horizon {h}: predictions {preds_inv.shape} and targets {trues_inv.shape} are misaligned")
                continue
            H = preds_inv.shape[2]
            # metrics at the first forecast step and at the final step of the horizon
            for metric_h in sorted({1, H}):
                preds_slice = preds_inv[:, 0, metric_h - 1]
                trues_slice = trues_inv[:, 0, metric_h - 1]
                metrics_records.append({
                    'horizon_eval': h,
                    'step': metric_h,
                    'rmse': math.sqrt(mean_squared_error(trues_slice, preds_slice)),
                    'mae': mean_absolute_error(trues_slice, preds_slice),
                    'mape': mape(trues_slice, preds_slice),
                })
            if args.model == 'transformer' and attn:
                _plot_attention_maps(attn)

    metrics_path = os.path.join(OUT_DIR, f'metrics_{args.model}.csv')
    pd.DataFrame(metrics_records).to_csv(metrics_path, index=False)
    print("Saved metrics to", metrics_path)

    print("Experiment completed. Brief images referenced at:", BRIEF_IMAGE_1, BRIEF_IMAGE_2)

# ---------------------------
# Optuna HP-search (simple integrated function)
# ---------------------------
def run_hp_search(args):
    """Tune Transformer hyper-parameters with Optuna, minimising the best validation MSE.

    Search space:
    - d_model: {64, 128, 256}
    - n_heads: {2, 4, 8}
    - n_layers: 1..4
    - lr: log-uniform in [1e-5, 1e-3]
    - d_ff: fixed at 2 * d_model

    Each of ``args.trials`` trials trains for up to ``args.epochs`` epochs on the
    same chronological 70/15 train/validation split, with train-only scaling.
    Trials may be pruned on their intermediate validation loss. The best parameters
    are printed and written to ``EXPERIMENTS_DIR/best_params.csv``.

    Raises:
        ImportError: if Optuna is not installed.
        RuntimeError: if PyTorch or scikit-learn is missing.
    """
    try:
        import optuna
    except ImportError as err:
        raise ImportError("Optuna not installed. Generate requirements and install (run run_all.sh).") from err
    if torch is None or StandardScaler is None:
        raise RuntimeError("PyTorch and scikit-learn are required for HP search. Install dependencies via generated requirements.txt and run run_all.sh.")

    # Everything below is identical for every trial. Prepare it once instead of
    # re-reading the CSV and refitting the scaler inside each objective call.
    df = pd.read_csv(args.data_path)
    series = df['y'].values.astype(float)
    T = len(series)
    train_end = int(T * 0.7)
    val_end = int(T * 0.85)
    train_series = series[:train_end]
    val_series = series[train_end:val_end]
    scaler = StandardScaler().fit(train_series.reshape(-1, 1))
    train_scaled = scaler.transform(train_series.reshape(-1, 1)).flatten()
    val_scaled = scaler.transform(val_series.reshape(-1, 1)).flatten()
    train_loader = DataLoader(TimeSeriesWindowDataset(train_scaled, args.input_len, args.horizon),
                              batch_size=args.batch_size, shuffle=True, collate_fn=collate_windows)
    val_loader = DataLoader(TimeSeriesWindowDataset(val_scaled, args.input_len, args.horizon),
                            batch_size=args.batch_size, shuffle=False, collate_fn=collate_windows)
    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    loss_fn = nn.MSELoss()

    def objective(trial):
        # sample hyperparams
        d_model = trial.suggest_categorical('d_model', [64, 128, 256])
        n_heads = trial.suggest_categorical('n_heads', [2, 4, 8])
        n_layers = trial.suggest_int('n_layers', 1, 4)
        # suggest_float(..., log=True) is the non-deprecated form of suggest_loguniform
        lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
        model = TransformerForecaster(input_dim=1, d_model=d_model, n_heads=n_heads, n_layers=n_layers,
                                      d_ff=d_model * 2, horizon=args.horizon,
                                      input_len=args.input_len).to(device)
        opt = torch.optim.Adam(model.parameters(), lr=lr)
        best_val = float('inf')
        for epoch in range(1, args.epochs + 1):
            train_one_epoch(model, train_loader, opt, loss_fn, device)
            val_loss = validate(model, val_loader, loss_fn, device)
            trial.report(val_loss, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()
            best_val = min(best_val, val_loss)
        return best_val

    study = optuna.create_study(direction='minimize')
    study.optimize(objective, n_trials=args.trials)
    print("Best params:", study.best_trial.params)
    os.makedirs(EXPERIMENTS_DIR, exist_ok=True)
    pd.DataFrame([study.best_trial.params]).to_csv(os.path.join(EXPERIMENTS_DIR, 'best_params.csv'), index=False)

# ---------------------------
# Helpers: write requirements, run_all.sh, tests, multi-seed script
# ---------------------------
def write_requirements():
    """Write ``requirements.txt`` (one unpinned package per line) into the working directory."""
    packages = [
        "numpy",
        "pandas",
        "matplotlib",
        "scikit-learn",
        "torch",
        "optuna",
        "prophet",
    ]
    with open("requirements.txt", "w") as handle:
        handle.write("\n".join(packages) + "\n")
    print("Wrote requirements.txt")

def write_run_all():
    """Write an executable ``run_all.sh`` that installs dependencies and runs a quick pipeline.

    The script:
    1. Creates and activates a ``venv``.
    2. Installs ``requirements.txt``.
    3. Generates the synthetic data and trains for 4 epochs.
    4. Runs a short Optuna search and the multi-seed evaluation.
    5. Lists ``results/``.

    The old first step ran the script with only ``--generate``. That performs a full
    default-epoch training run, whose outputs the next command immediately
    regenerates, so it is dropped. The file is written with LF line endings so bash
    can run it even when it is produced on Windows.
    """
    lines = [
        "#!/usr/bin/env bash",
        "set -e",
        "python3 -m venv venv",
        "source venv/bin/activate",
        "pip install -r requirements.txt",
        "# generate the synthetic dataset and run a quick training pass (reduced epochs)",
        "python3 advanced_ts_transformer_complete.py --generate --length 15000 --epochs 4 --input-len 256 --horizon 50",
        "# run a short hp-search (optional, takes time)",
        "python3 advanced_ts_transformer_complete.py --hp-search --trials 3 --epochs 5",
        "# run multi-seed eval",
        "python3 advanced_ts_transformer_complete.py --multi-seed --seeds 3",
        "ls -la results",
    ]
    with open("run_all.sh", "w", newline="\n") as f:
        f.write("\n".join(lines) + "\n")
    os.chmod("run_all.sh", 0o755)
    print("Wrote run_all.sh")

def write_tests_and_scripts():
    """Generate ``tests/test_model_shapes.py`` and ``multi_seed_eval.py`` in the working directory.

    The test file checks both forecasters' output shapes. The script re-runs the quick
    training command for several seeds and writes mean/std metrics to
    ``results/aggregate_metrics.csv``.
    """
    # tests/test_model_shapes.py
    test_content = '''import os
import sys

import torch

# pytest's default "prepend" import mode puts tests/ (not the project root) on
# sys.path, so add the root explicitly to make the module importable.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from advanced_ts_transformer_complete import TransformerForecaster, LSTMForecaster  # noqa: E402


def test_transformer_shapes():
    B, C, L = 2, 1, 128
    x = torch.randn(B, C, L)
    m = TransformerForecaster(input_dim=1, d_model=32, n_heads=4, n_layers=2, d_ff=64, horizon=10, input_len=L)
    y, attn = m(x, return_attn=True)
    assert y.shape == (B, 1, 10)
    assert isinstance(attn, list)


def test_lstm_shapes():
    B, C, L = 2, 1, 128
    x = torch.randn(B, C, L)
    m = LSTMForecaster(input_dim=1, hidden=32, n_layers=1, horizon=10)
    y = m(x)
    assert y.shape == (B, 1, 10)
'''
    os.makedirs(TESTS_DIR, exist_ok=True)
    with open(os.path.join(TESTS_DIR, "test_model_shapes.py"), "w") as f:
        f.write(test_content)

    # multi_seed_eval.py: argument-list subprocess calls (no shell parsing), reusing
    # the interpreter that runs the wrapper so the active venv is honoured
    multi_content = '''import argparse
import subprocess
import sys

import pandas as pd


def run_seed(seed):
    cmd = [sys.executable, "advanced_ts_transformer_complete.py",
           "--generate", "--length", "15000", "--epochs", "4", "--seed", str(seed)]
    print("Running:", " ".join(cmd))
    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--seeds", type=int, default=3)
    args = p.parse_args()
    metrics = []
    for s in range(args.seeds):
        run_seed(42 + s)
        # every run overwrites this file, so read it before starting the next seed
        metrics.append(pd.read_csv("results/metrics_transformer.csv"))
    combined = pd.concat(metrics)
    agg = combined.groupby(["horizon_eval", "step"]).agg({"rmse": ["mean", "std"], "mae": ["mean", "std"], "mape": ["mean", "std"]})
    print(agg)
    agg.to_csv("results/aggregate_metrics.csv")
'''
    with open("multi_seed_eval.py", "w") as f:
        f.write(multi_content)

    print("Wrote tests and multi_seed_eval.py")

# ---------------------------
# CLI
# ---------------------------
def parse_args(argv=None):
    """Build the command-line parser and parse ``argv``.

    Args:
        argv: argument strings to parse. ``None`` (the default) means ``sys.argv[1:]``,
            exactly as before. Passing a list lets notebooks and tests drive the CLI
            without touching ``sys.argv``, which holds the kernel's own flags in Jupyter.

    Returns:
        argparse.Namespace with one attribute per option (dashes become underscores).
    """
    p = argparse.ArgumentParser(description="Complete single-file project for Transformer time series forecasting")
    p.add_argument('--generate', action='store_true', help='Generate synthetic dataset')
    p.add_argument('--length', type=int, default=15000, help='Length of synthetic series to generate')
    p.add_argument('--freq', type=str, default='H', help='Frequency for datetime index')
    p.add_argument('--data-path', type=str, default=os.path.join(DATA_DIR, 'synthetic_15000.csv'),
                   help='CSV with a "y" column (and "ds" timestamps), used when --generate is not given')
    p.add_argument('--model', type=str, default='transformer', choices=['transformer', 'lstm'])
    p.add_argument('--input-len', type=int, dest='input_len', default=256, help='Input window length')
    p.add_argument('--horizon', type=int, default=50, help='Forecast horizon the model is trained for')
    p.add_argument('--batch-size', type=int, dest='batch_size', default=64)
    p.add_argument('--d-model', type=int, dest='d_model', default=128)
    p.add_argument('--n-heads', type=int, dest='n_heads', default=4)
    p.add_argument('--n-layers', type=int, dest='n_layers', default=3)
    p.add_argument('--d-ff', type=int, dest='d_ff', default=256)
    p.add_argument('--dropout', type=float, default=0.1)
    p.add_argument('--hidden', type=int, default=128, help='LSTM hidden size')
    p.add_argument('--lr', type=float, default=1e-4)
    p.add_argument('--epochs', type=int, default=30)
    p.add_argument('--patience', type=int, default=6, help='Early-stopping patience in epochs')
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--no-cuda', action='store_true', help='Force CPU even if CUDA is available')
    # special actions
    p.add_argument('--hp-search', action='store_true', help='Run integrated Optuna HP search')
    p.add_argument('--trials', type=int, default=20, help='HP search trials')
    p.add_argument('--multi-seed', action='store_true', help='Run multi-seed evaluation wrapper (reads/writes results files)')
    p.add_argument('--seeds', type=int, default=3, help='Number of seeds for --multi-seed')
    p.add_argument('--write-helpers', action='store_true', help='Write requirements, run_all.sh, tests and multi-seed script')
    return p.parse_args(argv)

# ---------------------------
# Main
# ---------------------------
if __name__ == "__main__":
    args = parse_args()

    if args.write_helpers:
        write_requirements()
        write_run_all()
        write_tests_and_scripts()
        print("Helpers written. Run ./run_all.sh to install and run a quick pipeline.")
        sys.exit(0)

    if args.hp_search:
        run_hp_search(args)
        sys.exit(0)

    if args.multi_seed:
        if os.path.exists("multi_seed_eval.py"):
            # Use the interpreter running this script (e.g. the active venv) rather than
            # whatever `python3` resolves to on PATH, which may lack dependencies or not exist.
            subprocess.run([sys.executable, "multi_seed_eval.py", "--seeds", str(args.seeds)], check=True)
        else:
            # fallback: repeat the experiment in-process with seeds 42, 43, ...
            all_metrics = []
            for s in range(args.seeds):
                args.seed = 42 + s
                run_experiment(args)
                all_metrics.append(pd.read_csv(os.path.join(OUT_DIR, f"metrics_{args.model}.csv")))
            if not all_metrics:
                print("No seeds requested; nothing to aggregate")
            else:
                combined = pd.concat(all_metrics)
                agg = combined.groupby(["horizon_eval", "step"]).agg({"rmse": ["mean", "std"], "mae": ["mean", "std"], "mape": ["mean", "std"]})
                agg.to_csv(os.path.join(OUT_DIR, "aggregate_metrics.csv"))
                print("Wrote aggregate_metrics.csv")
        sys.exit(0)

    # default: run experiment
    run_experiment(args)
