In [1]:
from transformers import TrainingArguments
from datasets import DatasetDict

from data_loader import DataLoader, DataLoaderConfig
from features import FeatureEngineer, FeatureConfig
from preprocessing import Preprocessor, PreprocessorConfig
from peft_config import PEFTConfig
from train import train
from evaluate import evaluate, EvalConfig
from predict import PredictConfig, predict_next

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 1. Load & preprocess data once
dl_cfg = DataLoaderConfig(
    csv_path="../data/BTCUSDT_1d.csv", symbol="BTCUSDT", timeframe="1d"
)
raw_ds: DatasetDict = DataLoader(dl_cfg).load()

In [3]:
print("Raw dataset structure:")
print(raw_ds)
print(f"\nSample data from train split:")
print(raw_ds["train"][0])

Raw dataset structure:
DatasetDict({
    train: Dataset({
        features: ['timestamp', 'open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe', 'exchange'],
        num_rows: 2303
    })
    validation: Dataset({
        features: ['timestamp', 'open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe', 'exchange'],
        num_rows: 288
    })
    test: Dataset({
        features: ['timestamp', 'open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe', 'exchange'],
        num_rows: 288
    })
})

Sample data from train split:
{'timestamp': '2017-08-17 00:00:00+00:00', 'open': 4261.48, 'high': 4485.39, 'low': 4200.74, 'close': 4285.08, 'volume': 795.150377, 'symbol': 'BTCUSDT', 'timeframe': '1d', 'exchange': None}


In [4]:
# Test feature engineering on a single split first
raw_train_ds = raw_ds["train"]
print(f"Raw train dataset columns: {raw_train_ds.column_names}")
print(f"Raw train dataset size: {len(raw_train_ds)}")

Raw train dataset columns: ['timestamp', 'open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe', 'exchange']
Raw train dataset size: 2303


In [5]:
feat_cfg = FeatureConfig()  # defaults
fe = FeatureEngineer(feat_cfg)

featured_train_ds = fe.transform(raw_train_ds)
featured_train_ds

Columns in dataset before feature engineering: ['timestamp', 'open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe', 'exchange']
Dropped 33 rows with NaN values from technical indicators
Dropped 33 rows with NaN values from technical indicators


Dataset({
    features: ['close', 'timestamp', 'log_return', 'rsi', 'macd', 'macd_signal', 'hour', 'dayofweek'],
    num_rows: 2270
})

In [6]:
# Debug: Check what happens step by step
import pandas as pd
import numpy as np
from ta.momentum import RSIIndicator
from ta.trend import MACD

# Test with just a few rows
sample_data = raw_train_ds.select(range(min(100, len(raw_train_ds))))
df = pd.DataFrame(
    {"timestamp": sample_data["timestamp"], "close": sample_data["close"]}
)
df["timestamp"] = pd.to_datetime(df["timestamp"])

print("Original DataFrame shape:", df.shape)
print("Close values sample:", df["close"].head())

# Test log return
df["log_return"] = np.log(df["close"] / df["close"].shift(1))
print("After log return - NaN count:", df["log_return"].isna().sum())

# Test RSI
rsi = RSIIndicator(df["close"], window=14).rsi()
df["rsi"] = rsi
print("After RSI - NaN count:", df["rsi"].isna().sum())

# Test MACD
macd = MACD(df["close"], window_slow=26, window_fast=12, window_sign=9)
df["macd"] = macd.macd()
df["macd_signal"] = macd.macd_signal()
print("After MACD - NaN count macd:", df["macd"].isna().sum())
print("After MACD - NaN count signal:", df["macd_signal"].isna().sum())

print("Total rows with any NaN:", df.isna().any(axis=1).sum())
print("Rows remaining after dropna:", len(df.dropna()))

Original DataFrame shape: (100, 2)
Close values sample: 0    4285.08
1    4108.37
2    4139.98
3    4086.29
4    4016.00
Name: close, dtype: float64
After log return - NaN count: 1
After RSI - NaN count: 13
After MACD - NaN count macd: 25
After MACD - NaN count signal: 33
Total rows with any NaN: 33
Rows remaining after dropna: 67


In [7]:
# Apply to all splits
ds_feat = {split: fe.transform(raw_ds[split]) for split in raw_ds}

Columns in dataset before feature engineering: ['timestamp', 'open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe', 'exchange']
Dropped 33 rows with NaN values from technical indicators
Columns in dataset before feature engineering: ['timestamp', 'open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe', 'exchange']
Dropped 33 rows with NaN values from technical indicators
Dropped 33 rows with NaN values from technical indicators
Columns in dataset before feature engineering: ['timestamp', 'open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe', 'exchange']
Dropped 33 rows with NaN values from technical indicators
Columns in dataset before feature engineering: ['timestamp', 'open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe', 'exchange']
Dropped 33 rows with NaN values from technical indicators
Columns in dataset before feature engineering: ['timestamp', 'open', 'high', 'low', 'close', 'volume', 'symbol', 'timeframe', 'exchange']
Dropped 33 rows with

In [8]:
print("Dataset after feature engineering:")
print(DatasetDict(ds_feat))
print(f"\nSample processed data:")
print(ds_feat["train"][0])
print(f"\nFeature columns: {ds_feat['train'].column_names}")

Dataset after feature engineering:
DatasetDict({
    train: Dataset({
        features: ['close', 'timestamp', 'log_return', 'rsi', 'macd', 'macd_signal', 'hour', 'dayofweek'],
        num_rows: 2270
    })
    validation: Dataset({
        features: ['close', 'timestamp', 'log_return', 'rsi', 'macd', 'macd_signal', 'hour', 'dayofweek'],
        num_rows: 255
    })
    test: Dataset({
        features: ['close', 'timestamp', 'log_return', 'rsi', 'macd', 'macd_signal', 'hour', 'dayofweek'],
        num_rows: 255
    })
})

Sample processed data:
{'close': 3910.04, 'timestamp': datetime.datetime(2017, 9, 19, 0, 0, tzinfo=<UTC>), 'log_return': -0.0314611759096125, 'rsi': 45.858821173501354, 'macd': -148.39876362737823, 'macd_signal': -119.00047689320108, 'hour': 0, 'dayofweek': 1}

Feature columns: ['close', 'timestamp', 'log_return', 'rsi', 'macd', 'macd_signal', 'hour', 'dayofweek']


In [9]:
# Continue with preprocessing to create sliding windows
prep_cfg = PreprocessorConfig(context_length=128, prediction_length=24, stride=1)
prep = Preprocessor(prep_cfg)
ds_windows = DatasetDict({split: prep.transform(ds_feat[split]) for split in ds_feat})

print("Dataset after windowing:")
print(ds_windows)
print(f"\nSample window:")
if len(ds_windows["train"]) > 0:
    sample_window = ds_windows["train"][0]
    print(f"Past values shape: {len(sample_window['past_values'])}")
    print(f"Future values shape: {len(sample_window['future_values'])}")
    print(f"Past values sample: {sample_window['past_values'][:5]}")
    print(f"Future values sample: {sample_window['future_values'][:5]}")
else:
    print("No training windows available!")

Map: 100%|██████████| 2270/2270 [00:00<00:00, 107528.04 examples/s]
Map: 100%|██████████| 255/255 [00:00<00:00, 77134.54 examples/s]
Map: 100%|██████████| 2270/2270 [00:00<00:00, 107528.04 examples/s]
Map: 100%|██████████| 255/255 [00:00<00:00, 77134.54 examples/s]
Map: 100%|██████████| 255/255 [00:00<00:00, 63663.54 examples/s]

Dataset after windowing:
DatasetDict({
    train: Dataset({
        features: ['past_values', 'future_values'],
        num_rows: 2119
    })
    validation: Dataset({
        features: ['past_values', 'future_values'],
        num_rows: 104
    })
    test: Dataset({
        features: ['past_values', 'future_values'],
        num_rows: 104
    })
})

Sample window:
Past values shape: 128
Future values shape: 24
Past values sample: [3910.04, 3900.0, 3609.99, 3595.87, 3780.0]
Future values sample: [11175.27, 11089.0, 11491.0, 11879.95, 11251.0]





In [10]:
# Test training with custom trainer
import os

os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"

from transformers import TrainingArguments
from train import train
from peft_config import PEFTConfig
from model_wrappers import ModelFactory

# Model config for PatchTST
model_kwargs = {
    "context_length": 128,
    "prediction_length": 24,
    "num_input_channels": 1,
}

# Create a test model to inspect its structure
test_model = ModelFactory.create("patchtst", **model_kwargs)
print("Model structure:")
for name, module in test_model.named_modules():
    if any(
        target in name.lower()
        for target in ["linear", "attention", "query", "key", "value", "proj"]
    ):
        print(f"  {name}: {type(module).__name__}")

print("\nAll module names:")
module_names = [name for name, _ in test_model.named_modules()]
for name in module_names[:20]:  # Show first 20 modules
    print(f"  {name}")
print("... (truncated)")

# Training config
training_args = TrainingArguments(
    output_dir="./test_output",
    num_train_epochs=1,  # Just 1 epoch for testing
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=1e-4,
    logging_steps=10,
    eval_strategy="steps",  # Fixed parameter name
    eval_steps=50,
    save_steps=100,
    report_to=None,  # Disable wandb
    remove_unused_columns=False,
    dataloader_pin_memory=False,
)

# PEFT config
peft_cfg = PEFTConfig(
    peft_method="lora",
    lora_r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "value", "key"],
)

print("Starting training test...")
trainer = train(
    model_key="patchtst",
    model_kwargs=model_kwargs,
    peft_cfg=peft_cfg,
    datasets=ds_windows,
    output_dir="./test_output",
    training_args=training_args,
)

Model structure:
  encoder.layers.0.self_attn.k_proj: Linear
  encoder.layers.0.self_attn.v_proj: Linear
  encoder.layers.0.self_attn.q_proj: Linear
  encoder.layers.0.self_attn.out_proj: Linear
  encoder.layers.1.self_attn.k_proj: Linear
  encoder.layers.1.self_attn.v_proj: Linear
  encoder.layers.1.self_attn.q_proj: Linear
  encoder.layers.1.self_attn.out_proj: Linear
  encoder.layers.2.self_attn.k_proj: Linear
  encoder.layers.2.self_attn.v_proj: Linear
  encoder.layers.2.self_attn.q_proj: Linear
  encoder.layers.2.self_attn.out_proj: Linear

All module names:
  
  scaler
  scaler.scaler
  patchifier
  masking
  encoder
  encoder.embedder
  encoder.embedder.input_embedding
  encoder.positional_encoder
  encoder.positional_encoder.positional_dropout
  encoder.layers
  encoder.layers.0
  encoder.layers.0.self_attn
  encoder.layers.0.self_attn.k_proj
  encoder.layers.0.self_attn.v_proj
  encoder.layers.0.self_attn.q_proj
  encoder.layers.0.self_attn.out_proj
  encoder.layers.0.dropout_

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Starting training test...
Starting training for model: patchtst
Output directory: ./test_output
Base model created: PatchTSTModel
Building PEFT model with method: lora
Output directory: ./test_output
Base model created: PatchTSTModel
Building PEFT model with method: lora


ValueError: Target modules {'key', 'query', 'value'} not found in the base model. Please check the target modules and try again.

In [None]:
# Now test training with correct target modules
import os

os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"

from transformers import TrainingArguments
from train import train
from peft_config import PEFTConfig

# Training config
training_args = TrainingArguments(
    output_dir="./test_output",
    num_train_epochs=1,  # Just 1 epoch for testing
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=1e-4,
    logging_steps=10,
    eval_strategy="steps",  # Fixed parameter name
    eval_steps=50,
    save_steps=100,
    report_to=None,  # Disable wandb
    remove_unused_columns=False,
    dataloader_pin_memory=False,
)

# PEFT config with correct target modules for PatchTST
peft_cfg = PEFTConfig(
    peft_method="lora",
    lora_r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "out_proj",
    ],  # Correct modules for PatchTST
)

# Model config for PatchTST
model_kwargs = {
    "context_length": 128,
    "prediction_length": 24,
    "num_input_channels": 1,
}

print("Starting training test with correct target modules...")
trainer = train(
    model_key="patchtst",
    model_kwargs=model_kwargs,
    peft_cfg=peft_cfg,
    datasets=ds_windows,
    output_dir="./test_output",
    training_args=training_args,
)

In [None]:
# Debug: Check the data shapes being passed
import torch

print("Checking data shapes:")
sample_batch = [ds_windows["train"][i] for i in range(4)]
print(f"Sample batch structure:")
for i, example in enumerate(sample_batch):
    print(f"  Example {i}:")
    print(f"    past_values type: {type(example['past_values'])}")
    print(
        f"    past_values shape/len: {len(example['past_values']) if hasattr(example['past_values'], '__len__') else 'N/A'}"
    )
    print(f"    future_values type: {type(example['future_values'])}")
    print(
        f"    future_values shape/len: {len(example['future_values']) if hasattr(example['future_values'], '__len__') else 'N/A'}"
    )
    if i == 0:  # Show actual values for first example
        print(f"    past_values[:5]: {example['past_values'][:5]}")
        print(f"    future_values[:5]: {example['future_values'][:5]}")
    print()


# Test collate function manually
def test_collate_fn(batch):
    """Test version of collate function"""
    print(f"Input batch length: {len(batch)}")

    # Extract past_values and future_values from the batch
    past_values = torch.tensor(
        [example["past_values"] for example in batch], dtype=torch.float32
    )
    future_values = torch.tensor(
        [example["future_values"] for example in batch], dtype=torch.float32
    )

    print(f"Raw past_values tensor shape: {past_values.shape}")
    print(f"Raw future_values tensor shape: {future_values.shape}")

    # Add feature dimension: (batch_size, sequence_length) -> (batch_size, sequence_length, num_features)
    past_values = past_values.unsqueeze(-1)  # Add feature dimension
    future_values = future_values.unsqueeze(-1)  # Add feature dimension

    print(f"Final past_values tensor shape: {past_values.shape}")
    print(f"Final future_values tensor shape: {future_values.shape}")

    return {
        "past_values": past_values,
        "future_values": future_values,
    }


print("Testing collate_fn:")
try:
    batch_result = test_collate_fn(sample_batch)
    print(f"Collated batch keys: {batch_result.keys()}")
    for key, tensor in batch_result.items():
        print(f"  {key} shape: {tensor.shape}")
        print(f"  {key} dtype: {tensor.dtype}")
except Exception as e:
    print(f"Error in collate_fn: {e}")
    import traceback

    traceback.print_exc()

In [None]:
# Test different tensor formats for PatchTST
sample_batch = [
    ds_windows["train"][i] for i in range(2)
]  # Use smaller batch for testing


def test_model_input_format(batch):
    """Test different tensor formats to see which one works"""
    past_values = torch.tensor(
        [example["past_values"] for example in batch], dtype=torch.float32
    )

    print("Testing different input formats:")

    # Format 1: (batch_size, sequence_length, num_features)
    format1 = past_values.unsqueeze(-1)
    print(f"Format 1 (B, S, F): {format1.shape}")

    # Format 2: (batch_size, num_features, sequence_length)
    format2 = past_values.unsqueeze(1)  # Add channel dimension at index 1
    print(f"Format 2 (B, F, S): {format2.shape}")

    # Test with the actual model
    from model_wrappers import ModelFactory

    model_kwargs = {
        "context_length": 128,
        "prediction_length": 24,
        "num_input_channels": 1,
    }
    model = ModelFactory.create("patchtst", **model_kwargs)

    print("\\nTesting formats with model:")

    try:
        print("Testing format 1 (B, S, F)...")
        output1 = model(past_values=format1)
        print(f"Format 1 SUCCESS! Output type: {type(output1)}")
    except Exception as e:
        print(f"Format 1 FAILED: {e}")

    try:
        print("Testing format 2 (B, F, S)...")
        output2 = model(past_values=format2)
        print(f"Format 2 SUCCESS! Output type: {type(output2)}")
    except Exception as e:
        print(f"Format 2 FAILED: {e}")

    return format2  # Return the working format


correct_format = test_model_input_format(sample_batch)

In [None]:
# Inspect actual PatchTST model output
from model_wrappers import ModelFactory

# Create test data
sample_batch = [ds_windows["train"][i] for i in range(2)]
past_values = torch.tensor(
    [example["past_values"] for example in sample_batch], dtype=torch.float32
)
future_values = torch.tensor(
    [example["future_values"] for example in sample_batch], dtype=torch.float32
)

# Correct format: (batch_size, sequence_length, num_features)
past_values = past_values.unsqueeze(-1)
future_values = future_values.unsqueeze(-1)

print(f"Input past_values shape: {past_values.shape}")
print(f"Target future_values shape: {future_values.shape}")

# Create model
model_kwargs = {
    "context_length": 128,
    "prediction_length": 24,
    "num_input_channels": 1,
}
model = ModelFactory.create("patchtst", **model_kwargs)

# Get model output
with torch.no_grad():
    outputs = model(past_values=past_values)

print(f"\\nModel output type: {type(outputs)}")
print(f"Model output attributes: {dir(outputs)}")

if hasattr(outputs, "__dict__"):
    print(f"Model output dict: {outputs.__dict__}")

# Check if it has specific attributes
for attr in ["prediction_outputs", "forecast", "last_hidden_state", "decoder_output"]:
    if hasattr(outputs, attr):
        val = getattr(outputs, attr)
        print(
            f"{attr}: {type(val)}, shape: {val.shape if hasattr(val, 'shape') else 'N/A'}"
        )

# Check if it's iterable
try:
    if isinstance(outputs, (tuple, list)):
        print(f"\\nOutput is a {type(outputs)} with {len(outputs)} elements:")
        for i, item in enumerate(outputs):
            print(
                f"  Element {i}: type={type(item)}, shape={item.shape if hasattr(item, 'shape') else 'N/A'}"
            )
    else:
        print(f"\\nOutput is not a tuple/list")
except Exception as e:
    print(f"Error checking output structure: {e}")

# Try to extract a prediction tensor
try:
    if hasattr(outputs, "prediction_outputs"):
        pred = outputs.prediction_outputs
    elif hasattr(outputs, "forecast"):
        pred = outputs.forecast
    elif hasattr(outputs, "last_hidden_state"):
        pred = outputs.last_hidden_state
    else:
        pred = outputs

    print(f"\\nExtracted prediction shape: {pred.shape}")
    print(f"Prediction tensor dtype: {pred.dtype}")

    # Test loss computation
    from torch.nn import MSELoss

    loss_fn = MSELoss()

    # Try direct loss computation
    try:
        loss = loss_fn(pred, future_values)
        print(f"Direct loss computation SUCCESS: {loss.item()}")
    except Exception as e:
        print(f"Direct loss computation FAILED: {e}")

        # Try with shape adjustment
        if len(pred.shape) == 4:
            pred_adjusted = pred.view(pred.shape[0], -1, pred.shape[-1])
            print(f"Adjusted prediction shape: {pred_adjusted.shape}")
            try:
                loss = loss_fn(
                    pred_adjusted[:, : future_values.shape[1], :], future_values
                )
                print(f"Adjusted loss computation SUCCESS: {loss.item()}")
            except Exception as e2:
                print(f"Adjusted loss computation FAILED: {e2}")

except Exception as e:
    print(f"Error extracting prediction: {e}")

In [None]:
# Test training with updated custom trainer
import os

os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"

from transformers import TrainingArguments
from train import train
from peft_config import PEFTConfig

# Training config - very minimal for testing
training_args = TrainingArguments(
    output_dir="./test_output",
    num_train_epochs=1,
    per_device_train_batch_size=2,  # Smaller batch for testing
    per_device_eval_batch_size=2,
    learning_rate=1e-4,
    logging_steps=5,
    eval_strategy="steps",
    eval_steps=20,
    save_steps=50,
    report_to=None,
    remove_unused_columns=False,
    dataloader_pin_memory=False,
    max_steps=10,  # Just run a few steps for testing
)

# PEFT config with correct target modules for PatchTST
peft_cfg = PEFTConfig(
    peft_method="lora",
    lora_r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)

# Model config for PatchTST
model_kwargs = {
    "context_length": 128,
    "prediction_length": 24,
    "num_input_channels": 1,
}

print("Starting training test with updated trainer...")
try:
    trainer = train(
        model_key="patchtst",
        model_kwargs=model_kwargs,
        peft_cfg=peft_cfg,
        datasets=ds_windows,
        output_dir="./test_output",
        training_args=training_args,
    )
    print("Training completed successfully!")
    print(f"Final training state: {trainer.state}")
except Exception as e:
    print(f"Training failed with error: {e}")
    import traceback

    traceback.print_exc()

In [None]:
# Test our custom trainer step-by-step
import importlib
import train

importlib.reload(train)

from model_wrappers import ModelFactory
from peft_config import PEFTConfig
import torch


# Define collate function directly for testing
def test_collate_fn(batch):
    """Test version of collate function"""
    print(f"Collating batch of size: {len(batch)}")

    # Extract past_values and future_values from the batch
    past_values = torch.tensor(
        [example["past_values"] for example in batch], dtype=torch.float32
    )
    future_values = torch.tensor(
        [example["future_values"] for example in batch], dtype=torch.float32
    )

    print(
        f"Raw tensors - past_values: {past_values.shape}, future_values: {future_values.shape}"
    )

    # Add feature dimension: (batch_size, sequence_length) -> (batch_size, sequence_length, num_features)
    past_values = past_values.unsqueeze(-1)  # Add feature dimension
    future_values = future_values.unsqueeze(-1)  # Add feature dimension

    print(
        f"Final tensors - past_values: {past_values.shape}, future_values: {future_values.shape}"
    )

    return {
        "past_values": past_values,
        "future_values": future_values,
    }


print("=== Testing collate function ===")
sample_batch = [ds_windows["train"][i] for i in range(2)]
collated = test_collate_fn(sample_batch)

# Test model creation
print("\\n=== Testing model creation ===")
model_kwargs = {
    "context_length": 128,
    "prediction_length": 24,
    "num_input_channels": 1,
}
base_model = ModelFactory.create("patchtst", **model_kwargs)
print(f"Base model created: {type(base_model)}")

# Test PEFT wrapping
print("\\n=== Testing PEFT wrapping ===")
peft_cfg = PEFTConfig(
    peft_method="lora",
    lora_r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
peft_model = train.build_peft_model(base_model, peft_cfg)
print(f"PEFT model created: {type(peft_model)}")

# Test direct model call
print("\\n=== Testing direct model call ===")
try:
    with torch.no_grad():
        outputs = peft_model(past_values=collated["past_values"])
    print(f"Model call SUCCESS! Output type: {type(outputs)}")
    if hasattr(outputs, "prediction_outputs"):
        print(f"Prediction outputs shape: {outputs.prediction_outputs.shape}")
except Exception as e:
    print(f"Model call FAILED: {e}")
    import traceback

    traceback.print_exc()

In [None]:
# Test a minimal training step
from transformers import TrainingArguments

print("=== Testing custom trainer compute_loss ===")
training_args = TrainingArguments(
    output_dir="./test_output",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    learning_rate=1e-4,
    report_to=None,
    remove_unused_columns=False,  # Important: don't remove columns
)

# Create minimal trainer
trainer = train.TimeSeriesTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=ds_windows["train"].select(range(4)),  # Very small subset
    eval_dataset=ds_windows["validation"].select(range(4)),
    data_collator=test_collate_fn,
)

print("Trainer created successfully")

# Test compute_loss directly
try:
    print("Testing compute_loss...")
    loss = trainer.compute_loss(peft_model, collated)
    print(f"Compute_loss SUCCESS! Loss: {loss.item()}")
except Exception as e:
    print(f"Compute_loss FAILED: {e}")
    import traceback

    traceback.print_exc()

# Test one training step
print("\\n=== Testing one training step ===")
try:
    # Get a batch from the dataloader
    train_dataloader = trainer.get_train_dataloader()
    batch = next(iter(train_dataloader))
    print(f"Dataloader batch keys: {batch.keys()}")
    for key, tensor in batch.items():
        print(f"  {key}: {tensor.shape}")

    # Test training step
    peft_model.train()
    loss = trainer.training_step(peft_model, batch)
    print(f"Training step SUCCESS! Loss: {loss}")
except Exception as e:
    print(f"Training step FAILED: {e}")
    import traceback

    traceback.print_exc()

In [None]:
# Test training with fixed trainer
import importlib
import train

importlib.reload(train)

import os

os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"

from transformers import TrainingArguments

# Use CPU for now to avoid device issues
import torch

device = torch.device("cpu")

print("=== Testing fixed training ===")

# Recreate model on CPU
model_kwargs = {
    "context_length": 128,
    "prediction_length": 24,
    "num_input_channels": 1,
}
base_model = ModelFactory.create("patchtst", **model_kwargs)
base_model = base_model.to(device)

peft_cfg = PEFTConfig(
    peft_method="lora",
    lora_r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
peft_model = train.build_peft_model(base_model, peft_cfg)
peft_model = peft_model.to(device)

print(f"Model device: {next(peft_model.parameters()).device}")

# Training config - very minimal for testing
training_args = TrainingArguments(
    output_dir="./test_output",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=1e-4,
    logging_steps=2,
    eval_strategy="steps",
    eval_steps=5,
    save_steps=10,
    report_to=None,
    remove_unused_columns=False,
    dataloader_pin_memory=False,
    max_steps=5,  # Just run a few steps for testing
)

try:
    trainer = train.train(
        model_key="patchtst",
        model_kwargs=model_kwargs,
        peft_cfg=peft_cfg,
        datasets={
            "train": ds_windows["train"].select(range(10)),
            "validation": ds_windows["validation"].select(range(5)),
        },
        output_dir="./test_output",
        training_args=training_args,
    )
    print("\\n🎉 TRAINING COMPLETED SUCCESSFULLY! 🎉")
    print(f"Final training state: {trainer.state.global_step} steps completed")
    print(
        f"Final loss: {trainer.state.log_history[-1] if trainer.state.log_history else 'N/A'}"
    )

except Exception as e:
    print(f"Training failed with error: {e}")
    import traceback

    traceback.print_exc()