In [None]:
import pandas as pd
label = pd.read_csv("/home/iatell/projects/meta-learning/data/seq_line_labels.csv")
label["seq_len"] = label["endIndex"] - label["startIndex"]
label

# model


In [1]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.nn.utils.rnn import pack_padded_sequence
import torch.nn.functional as F
def mdn_split_params(raw_params, n_components):
    """
    raw_params: (B, T, 3K) tensor from mdn_head
    returns:
        pi    (B, T, K)  mixture weights
        mu    (B, T, K)  means
        sigma (B, T, K)  std devs
    """
    B, T, threeK = raw_params.shape
    assert threeK == 3 * n_components
    raw = raw_params.view(B, T, n_components, 3)

    pi = raw[..., 0]                 # (B,T,K)
    mu = raw[..., 1]                 # (B,T,K)
    sigma = raw[..., 2]              # (B,T,K)

    pi = F.softmax(pi, dim=-1)       # weights sum to 1
    sigma = F.softplus(sigma) + 1e-4 # strictly positive
    return pi, mu, sigma


def mdn_nll(y, pi, mu, sigma, mask=None):
    """
    Negative log-likelihood loss for 1D MDN.

    Args:
        y     : (B, T) or (B, T, 1) targets
        pi    : (B, T, K) mixture weights
        mu    : (B, T, K) means
        sigma : (B, T, K) std devs
        mask  : (B, T) optional mask for valid timesteps

    Returns:
        scalar loss (mean over batch and time)
    """
    B, T, K = pi.shape

    # ensure shape (B,T,1)
    if y.dim() == 2:  
        y = y.unsqueeze(-1)  

    # crop labels if longer than model output
    if y.size(1) > T:
        y = y[:, :T]

    y = y.expand(-1, -1, K)  # (B,T,K)

    # log probability of each Gaussian
    log_prob = -0.5 * ((y - mu) / (sigma + 1e-8))**2 \
               - torch.log(sigma + 1e-8) \
               - 0.5 * torch.log(torch.tensor(2.0 * torch.pi, device=y.device))

    # log-sum-exp over mixtures
    log_mix = torch.log(pi + 1e-8) + log_prob
    log_sum = torch.logsumexp(log_mix, dim=-1)  # (B,T)

    if mask is not None:
        log_sum = log_sum * mask.float()
        denom = mask.float().sum().clamp_min(1.0)
        nll = -log_sum.sum() / denom
    else:
        nll = -log_sum.mean()
    return nll

class CNNLSTM_MDN(pl.LightningModule):
    def __init__(self, input_dim, max_len_y, hidden_dim=128, num_layers=1,
                 lr=1e-3, n_components=5, cnn_channels=64, dropout=0.1):
        super().__init__()
        self.save_hyperparameters()

        # CNN
        self.conv1 = nn.Conv1d(input_dim, cnn_channels, kernel_size=1)
        self.conv3 = nn.Conv1d(input_dim, cnn_channels, kernel_size=3, padding=1)
        self.bn1   = nn.BatchNorm1d(cnn_channels)
        self.bn3   = nn.BatchNorm1d(cnn_channels)

        # LSTM
        fused_dim = 2 * cnn_channels
        self.lstm = nn.LSTM(fused_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, dropout=dropout if num_layers > 1 else 0)

        # MDN output (only branch)
        self.mdn_head = nn.Linear(hidden_dim, 3 * n_components)
        self.n_components = n_components
        self.lr = lr
        self.max_len_y = max_len_y

    def forward(self, X, lengths=None):
        # CNN
        x = X.transpose(1, 2)
        x1 = F.relu(self.bn1(self.conv1(x)))
        x3 = F.relu(self.bn3(self.conv3(x)))
        xf = torch.cat([x1, x3], dim=1).transpose(1, 2)

        # LSTM
        h_seq, _ = self.lstm(xf)

        # MDN
        raw = self.mdn_head(h_seq[:, :self.max_len_y])  # (B,T,3K)
        pi, mu, sigma = mdn_split_params(raw, self.n_components)
        return {"pi": pi, "mu": mu, "sigma": sigma}

    def training_step(self, batch, batch_idx):
        X, y_line, lengths = batch
        mdn = self(X, lengths)

        T_out = mdn["pi"].size(1)  # model output length
        y_line = y_line[:, :T_out]  # crop labels to match

        # Mask padded timesteps
        if lengths is not None:
            idx = torch.arange(T_out, device=y_line.device).unsqueeze(0)
            mask = (idx < lengths.unsqueeze(1)).float()
        else:
            mask = None

        loss = mdn_nll(y_line, mdn["pi"], mdn["mu"], mdn["sigma"], mask)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y_line, lengths = batch
        mdn = self(X, lengths)

        T_out = mdn["pi"].size(1)
        y_line = y_line[:, :T_out]

        loss = mdn_nll(y_line, mdn["pi"], mdn["mu"], mdn["sigma"])
        self.log("val/loss", loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


# train

In [2]:
import sys
from pathlib import Path

# Current notebook location
notebook_path = Path().resolve()

# Add parent folder (meta/) to sys.path
sys.path.append(str(notebook_path.parent))
import joblib
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
from datetime import datetime
from preprocess.multi_regression_seq_dif import preprocess_sequences_csv_multilines
# from models.LSTM.lstm_multi_line_reg_seq_dif import LSTMMultiRegressor
from utils.print_batch import print_batch
from utils.to_address import to_address
from utils.json_to_csv import json_to_csv_in_memory
from utils.padding_batch_reg import collate_batch
import pandas as pd
import io
import numpy as np
import os

from sklearn.metrics import accuracy_score, f1_score
# ---------------- Evaluation ---------------- #
@torch.no_grad()
def evaluate_model_mdn(model, val_loader, zero_idx=0, threshold=0.7):
    """
    Evaluate CNN–LSTM–MDN model.

    Args
    ----
    model : pl.LightningModule with MDN forward
    val_loader : DataLoader yielding (X, y, lengths)
    zero_idx : which mixture component is considered "no-line" (usually 0)
    threshold : if pi[:,t,zero_idx] > threshold → predict invalid

    Returns
    -------
    dict with mse, mae, acc, f1
    """
    model.eval()
    all_preds_reg, all_labels_reg = [], []
    all_preds_len, all_labels_len = [], []

    device = next(model.parameters()).device

    with torch.no_grad():
        for X_batch, y_batch, lengths in val_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            lengths = lengths.to(device)

            # Forward pass
            mdn = model(X_batch, lengths)
            if isinstance(mdn, dict):
                pi, mu, sigma = mdn["pi"], mdn["mu"], mdn["sigma"]
            else:  # old tuple version
                pi, mu, sigma = mdn

            # Crop labels to model output length
            T_out = pi.size(1)
            y_batch = y_batch[:, :T_out]

            # Expected regression output
            y_pred = (pi * mu).sum(dim=-1)  # (B, T)
            all_preds_reg.append(y_pred.cpu().numpy())
            all_labels_reg.append(y_batch.cpu().numpy())

            # Validity from MDN: high weight on zero component
            pi_zero = pi[:, :, zero_idx]  # (B,T)
            pred_valid = (pi_zero < (1 - threshold)).long().cpu().numpy()  # 1 = valid

            # True valid positions from sequence lengths
            true_valid = (torch.arange(T_out, device=device)
                          .unsqueeze(0) < lengths.unsqueeze(1)).long().cpu().numpy()

            all_preds_len.extend(pred_valid.flatten().tolist())
            all_labels_len.extend(true_valid.flatten().tolist())

    # ----- Regression metrics -----
    all_preds_reg = np.vstack(all_preds_reg)  # (total_samples, T)
    all_labels_reg = np.vstack(all_labels_reg)
    mse = ((all_preds_reg - all_labels_reg) ** 2).mean()
    mae = np.abs(all_preds_reg - all_labels_reg).mean()

    # ----- Validity metrics -----
    acc = accuracy_score(all_labels_len, all_preds_len)
    f1 = f1_score(all_labels_len, all_preds_len, average="macro")

    print("\n📊 Validation Metrics (MDN):")
    print(f"  Regression → MSE: {mse:.6f}, MAE: {mae:.6f}")
    print(f"  Validity   → Acc: {acc:.4f}, F1: {f1:.4f}")

    return {"mse": mse, "mae": mae, "acc": acc, "f1": f1}

# ---------------- Train ---------------- #
def train_model(
    data_csv,
    labels_csv,
    model_out_dir="models/saved_models",
    do_validation=True,
    hidden_dim=128,
    num_layers=1,
    lr=0.001,
    batch_size=32,
    max_epochs=200,
    save_model=False,
    return_val_accuracy = True
):

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_out = f"{model_out_dir}/lstm_model_multireg_{timestamp}.pt"
    meta_out  = f"{model_out_dir}/lstm_meta_multireg_{timestamp}.pkl"

    # Preprocess: pad linePrices and sequences
    if do_validation:
        train_ds, val_ds, df, feature_cols, max_len_y = preprocess_sequences_csv_multilines(
            data_csv, labels_csv,
            val_split=True,
            for_xgboost=False,
            debug_sample=False
        )
    else:
        train_ds, df, feature_cols, max_len_y = preprocess_sequences_csv_multilines(
            data_csv, labels_csv,
            val_split=False,
            for_xgboost=False,
            debug_sample=False
        )
        val_ds = None

    sample = train_ds[0][0]  # first sample's features
    if isinstance(sample, dict):  # multiple feature groups
        input_dim = sample['main'].shape[1]
    else:  # single tensor
        input_dim = sample.shape[1]

    model = CNNLSTM_MDN(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        max_len_y=max_len_y,
        lr=lr
    )

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
    val_loader = DataLoader(val_ds, batch_size=batch_size, collate_fn=collate_batch) if val_ds else None

    trainer = pl.Trainer(max_epochs=max_epochs, accelerator="auto", devices=1)
    trainer.fit(model, train_loader, val_loader)

    if save_model:
        os.makedirs(model_out_dir, exist_ok=True)
        trainer.save_checkpoint(model_out)
        joblib.dump({
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "num_layers": num_layers,
            "max_len_y": max_len_y,
            "feature_cols": feature_cols
        }, meta_out)
        print(f"✅ Model saved to {model_out}")
        print(f"✅ Meta saved to {meta_out}")
    # --- Evaluation --- #
    if do_validation:
        mse, mae, acc, f1 = evaluate_model_mdn(model, val_loader)
        if return_val_accuracy:
            return {"mse": mse, "mae": mae, "acc": acc, "f1": f1}
        
if __name__ == "__main__":
    train_model(
        "/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles_prop.csv",
        "/home/iatell/projects/meta-learning/data/seq_line_labels.csv",
        save_model=True,
        do_validation=True
    )


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-09-03 02:47:46.758518: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-03 02:47:46.8

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


✅ Model saved to models/saved_models/lstm_model_multireg_20250903_024745.pt
✅ Meta saved to models/saved_models/lstm_meta_multireg_20250903_024745.pkl

📊 Validation Metrics (MDN):
  Regression → MSE: 0.115894, MAE: 0.267756
  Validity   → Acc: 0.9000, F1: 0.4737


# server

In [None]:
# server_cnn_mdn.py
import glob
import joblib
import torch
import numpy as np
from flask import Flask, request, jsonify, render_template
# from models.LSTM.cnn_lstm_mdn import CNNLSTM_MDN
# from preprocess import sequential as dp

app = Flask(__name__)

# ---------------- Load model and meta ----------------
meta_path = glob.glob("models/saved_models/lstm_meta_multireg_*.pkl")[0]
state_path = glob.glob("models/saved_models/lstm_model_multireg*.pt")[0]

meta = joblib.load(meta_path)
FEATURES = meta['feature_cols']  # load the same columns used during training

def one_sample(df_slice):

    return df_slice[FEATURES].values.astype(np.float32)
def load_raw_data_serve(candle_csv, label_csv):
    candles = pd.read_csv(candle_csv, parse_dates=['timestamp'])
    labels  = pd.read_csv(label_csv, header=None,
                          names=['timestamp', 'last_close', 'line_raw'])
    labels['timestamp'] = pd.to_datetime(labels['timestamp'])

    # Keep all candles
    df = pd.merge(candles, labels[['timestamp', 'line_raw']], on='timestamp', how='left')

    # Extract line price safely
    df['label_price'] = df['line_raw'].str.extract(r"\(([^)]+)\)")[0].astype(float)
    df['has_line'] = (~df['label_price'].isna()).astype(int)

    # Optional log factor for features
    df['log_factor'] = np.where(df['has_line'],
                                np.log(np.maximum(df['label_price'], 1e-6)),
                                np.nan)

    return df

model = CNNLSTM_MDN(
    input_dim=meta['input_dim'],
    max_len_y=meta.get('max_len_y', 20),   # default to 20 if not saved
    hidden_dim=meta.get('hidden_dim', 128),
    num_layers=meta.get('num_layers', 1),
    n_components=meta.get('n_components', 5)
)
# Load Lightning checkpoint
model = CNNLSTM_MDN.load_from_checkpoint(state_path)
model.eval()

# ---------------- Load data ----------------
df = load_raw_data_serve(
    "/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles_prop.csv",
    "/home/iatell/projects/meta-learning/data/ohlcv_log(2).csv"
)

# ---------------- Routes ----------------
@app.route("/")
def home():
    return render_template("sequential.html")


@app.route("/candles")
def candles():
    dense = df.set_index('timestamp').asfreq('D').ffill()
    return jsonify([
        {'time': int(ts.timestamp()),
         'open': float(row.open),
         'high': float(row.high),
         'low': float(row.low),
         'close': float(row.close)}
        for ts, row in dense.iterrows()
    ])


@app.route("/predict", methods=['POST'])
def predict():
    data = request.get_json(force=True)
    idxs = data.get('idxs')         # must be list of indices
    seq_len = data.get('seq_len')   # must be int

    if not idxs or not isinstance(idxs, list):
        return jsonify({"error": "Provide 'idxs' as a list"}), 400
    if not seq_len or not isinstance(seq_len, int):
        return jsonify({"error": "Provide 'seq_len' as an int"}), 400

    all_pred_prices = []
    all_pred_sigmas = []
    all_pi = []

    for idx in idxs:
        seq_df = df.iloc[idx - seq_len + 1 : idx + 1]
        last_close = seq_df.iloc[-1]['close']
        X_np = one_sample(seq_df)
        X_t = torch.from_numpy(X_np.astype(np.float32)).unsqueeze(0)

        with torch.no_grad():
            mdn_out = model(X_t)

        pi    = mdn_out['pi'][0, -1].cpu().numpy()
        mu    = mdn_out['mu'][0, -1].cpu().numpy()
        sigma = mdn_out['sigma'][0, -1].cpu().numpy()

        all_pred_prices.append((last_close * np.exp(mu)).tolist())
        all_pred_sigmas.append((last_close * sigma).tolist())
        all_pi.append(pi.tolist())

    return jsonify({
        'pred_prices': all_pred_prices,
        'pred_sigmas': all_pred_sigmas,
        'pi': all_pi
    })



if __name__ == '__main__':
    app.run(debug=True,use_reloader=False)


 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [03/Sep/2025 03:55:33] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [03/Sep/2025 03:55:33] "GET /candles HTTP/1.1" 200 -
127.0.0.1 - - [03/Sep/2025 03:55:34] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [03/Sep/2025 03:55:34] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [03/Sep/2025 03:55:34] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [03/Sep/2025 03:55:34] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [03/Sep/2025 03:55:34] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [03/Sep/2025 03:55:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [03/Sep/2025 03:55:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [03/Sep/2025 03:55:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [03/Sep/2025 03:55:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [03/Sep/2025 03:55:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [03/Sep/2025 03:55:41] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [03/Sep/2025 03:55:41] "POST /predict HTTP/1.1" 200 -
