In [None]:
import pandas as pd
label = pd.read_csv("/home/iatell/projects/meta-learning/data/seq_line_labels.csv")
label["seq_len"] = label["endIndex"] - label["startIndex"]
label

In [2]:

import pandas as pd
df = pd.read_csv("/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles_prop.csv")
df

Unnamed: 0,timestamp,open,high,low,close,volume,upper_shadow,body,lower_shadow,Candle_Color,upper_body_ratio,lower_body_ratio,upper_lower_body_ratio
0,2018-01-01,13707.91,13818.55,12750.00,13380.00,8607.15640,0.076003,-0.225254,0.432772,1,0.337410,1.921259,0.175619
1,2018-01-02,13382.16,15473.49,12890.02,14675.11,20078.16540,0.540071,0.874627,0.332912,2,0.617487,0.380633,1.622262
2,2018-01-03,14690.00,15307.56,14150.00,14919.51,15905.48210,0.263644,0.155931,0.366880,2,1.690776,2.352839,0.718611
3,2018-01-04,14919.51,15280.00,13918.04,15059.54,25224.41500,0.150006,0.095280,0.681423,2,1.574377,5.000000,0.220136
4,2018-01-05,15059.56,17176.24,14600.00,16960.39,23251.35200,0.144690,1.274181,0.308056,2,0.113556,0.241768,0.469688
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1599,2022-05-19,28715.33,30545.18,28691.38,30319.23,67877.36415,0.109006,0.773779,0.011554,2,0.140875,0.014932,5.000000
1600,2022-05-20,30319.22,30777.33,28730.00,29201.01,60517.25325,0.221063,-0.539597,0.227288,1,0.409682,0.421218,0.972612
1601,2022-05-21,29201.01,29656.18,28947.28,29445.06,20987.13124,0.103235,0.119338,0.124071,2,0.865069,1.039664,0.832066
1602,2022-05-22,29445.07,30487.99,29255.11,30293.94,36158.98748,0.095648,0.418411,0.093632,2,0.228598,0.223780,1.021531


# model


## two head lstm

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.nn.utils.rnn import pack_padded_sequence

class LSTMMultiRegressor(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim, num_layers, max_len_y, lr=0.001, threshold=0.5):
        super().__init__()
        self.save_hyperparameters()

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        # Main regression output: predict all linePrices up to max_len_y
        self.fc_reg = nn.Linear(hidden_dim, max_len_y)
        # Length prediction branch: logits per possible line (max_len_y)
        self.fc_len = nn.Linear(hidden_dim, max_len_y)
        self.lr = lr
        self.threshold = threshold

        self.loss_fn_reg = nn.MSELoss(reduction="none")  # we'll mask padded values
        self.loss_fn_len = nn.BCEWithLogitsLoss()        # treat as multi-label classification

    def forward(self, x, lengths):
        x = x["main"] 
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (hn, _) = self.lstm(packed)
        last_h = hn[-1]

        y_pred = self.fc_reg(last_h)      # regression outputs
        len_logits = self.fc_len(last_h)  # logits per possible line
        return y_pred, len_logits

    def training_step(self, batch, batch_idx):
        X, y, lengths = batch
        y_pred, len_logits = self(X, lengths)

        # --- Regression loss with masking ---
        mask = (y != 0).float()  # assume padding = 0
        loss_reg = (self.loss_fn_reg(y_pred, y) * mask).sum() / mask.sum()

        # --- Length loss ---
        target_lengths = torch.zeros_like(len_logits, dtype=torch.float32)
        for i, l in enumerate(lengths):
            target_lengths[i, :l] = 1.0   # first l positions are 1, rest are 0

        loss_len = self.loss_fn_len(len_logits, target_lengths)

        loss = loss_reg + 0.1 * loss_len
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_loss_reg", loss_reg, prog_bar=True)
        self.log("train_loss_len", loss_len, prog_bar=True)

        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def predict_length(self, len_logits):
        """
        Convert logits to predicted number of lines using threshold.
        """
        probs = torch.sigmoid(len_logits)
        pred_len = (probs > self.threshold).sum(dim=1)
        return pred_len


## two head lstm sum of logits loss

In [5]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.nn.utils.rnn import pack_padded_sequence

class LSTMMultiRegressor(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim, num_layers, max_len_y, lr=0.001, threshold=0.5):
        super().__init__()
        self.save_hyperparameters()

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        # Main regression output: predict all linePrices up to max_len_y
        self.fc_reg = nn.Linear(hidden_dim, max_len_y)
        # Length prediction branch: logits per possible line (max_len_y)
        self.fc_len = nn.Linear(hidden_dim, max_len_y)
        self.lr = lr
        self.threshold = threshold

        self.loss_fn_reg = nn.MSELoss(reduction="none")  # we'll mask padded values
        self.loss_fn_len = nn.BCEWithLogitsLoss()        # treat as multi-label classification

    def forward(self, x, lengths):
        x = x["main"] 
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (hn, _) = self.lstm(packed)
        last_h = hn[-1]

        y_pred = self.fc_reg(last_h)      # regression outputs
        len_logits = self.fc_len(last_h)  # logits per possible line
        return y_pred, len_logits

    def training_step(self, batch, batch_idx):
        X, y, lengths = batch
        y_pred, len_logits = self(X, lengths)

        # --- Regression loss with masking ---
        mask = (y != 0).float()  # assume padding = 0
        loss_reg = (self.loss_fn_reg(y_pred, y) * mask).sum() / mask.sum()

        # --- Length loss ---
        target_lengths = torch.zeros_like(len_logits, dtype=torch.float32)
        for i, l in enumerate(lengths):
            target_lengths[i, :l] = 1.0   # first l positions are 1, rest are 0

        loss_len = self.loss_fn_len(len_logits, target_lengths)

        loss = loss_reg + 0.1 * loss_len
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def predict_length(self, len_logits):
        """
        Convert logits to predicted number of lines using threshold.
        """
        probs = torch.sigmoid(len_logits)
        pred_len = (probs > self.threshold).sum(dim=1)
        return pred_len


## two head lstm soft thresholding

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.nn.utils.rnn import pack_padded_sequence
import torch.nn.functional as F

class LSTMMultiRegressor(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim, num_layers, max_len_y, lr=0.001, threshold=0.5, k_soft=20.0):
        super().__init__()
        self.save_hyperparameters()

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        # Main regression output: predict all linePrices up to max_len_y
        self.fc_reg = nn.Linear(hidden_dim, max_len_y)
        # Length prediction branch: logits per possible line (max_len_y)
        self.fc_len = nn.Linear(hidden_dim, max_len_y)
        self.lr = lr
        self.threshold = threshold
        self.k_soft = k_soft  # slope for soft thresholding

        self.loss_fn_reg = nn.MSELoss(reduction="none")  # we'll mask padded values
        self.loss_fn_len = nn.BCEWithLogitsLoss()        # treat as multi-label classification

    def forward(self, x, lengths):
        x = x["main"] 
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (hn, _) = self.lstm(packed)
        last_h = hn[-1]

        y_pred = self.fc_reg(last_h)      # regression outputs
        len_logits = self.fc_len(last_h)  # logits per possible line
        return y_pred, len_logits

    def training_step(self, batch, batch_idx):
        X, y, lengths = batch
        y_pred, len_logits = self(X, lengths)

        # --- Regression loss with masking ---
        mask = (y != 0).float()  # assume padding = 0
        loss_reg = (self.loss_fn_reg(y_pred, y) * mask).sum() / mask.sum()

        # --- Soft threshold length loss ---
        # Convert logits into soft 0-1 mask around threshold
        soft_mask = torch.sigmoid(self.k_soft * (len_logits - self.threshold))
        # Compute soft expected length per sample
        soft_len = soft_mask.sum(dim=1)
        # True length
        true_len = lengths.float()
        # MSE loss on soft expected length
        loss_len = F.mse_loss(soft_len, true_len)

        # Combine losses
        loss = loss_reg + 0.1 * loss_len
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def predict_length(self, len_logits):
        """
        Convert logits to predicted number of lines using threshold.
        """
        probs = torch.sigmoid(len_logits)
        pred_len = (probs > self.threshold).sum(dim=1)
        return pred_len


## FNNCNN

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence

def mdn_split_params(raw_params, n_components):
    """
    raw_params: (B, 3K) tensor from mdn_head
    returns:
        pi    (B, K) mixture weights
        mu    (B, K) means
        sigma (B, K) std devs
    """
    B, threeK = raw_params.shape
    assert threeK == 3 * n_components

    raw = raw_params.view(B, n_components, 3)

    pi = raw[..., 0]                 # (B,K)
    mu = raw[..., 1]                 # (B,K)
    sigma = raw[..., 2]              # (B,K)

    pi = F.softmax(pi, dim=-1)       # weights sum to 1
    sigma = F.softplus(sigma) + 1e-4 # strictly positive
    return pi, mu, sigma


def mdn_nll_multitarget(y_line, pi, mu, sigma):
    """
    Negative log-likelihood for MDN with multiple valid targets per sample.
    Args:
        y_line : (B, L) padded targets (0 where invalid)
        pi, mu, sigma : (B, K) MDN params
    Returns:
        scalar loss
    """
    B, K = mu.shape
    losses = []

    for b in range(B):
        valid_y = y_line[b][y_line[b] > 0]  # (M,)
        if len(valid_y) == 0:
            continue

        # expand to (M, K)
        y_exp = valid_y.unsqueeze(-1).expand(-1, K)

        log_prob = -0.5 * ((y_exp - mu[b]) / (sigma[b] + 1e-8))**2 \
                   - torch.log(sigma[b] + 1e-8) \
                   - 0.5 * torch.log(torch.tensor(2.0 * torch.pi, device=y_line.device))

        log_mix = torch.log(pi[b] + 1e-8) + log_prob
        log_sum = torch.logsumexp(log_mix, dim=-1)  # (M,)

        losses.append(-log_sum.mean())

    if len(losses) == 0:
        return torch.tensor(0.0, device=y_line.device, requires_grad=True)

    return torch.stack(losses).mean()


class CNNLSTM_MDN(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim=128, num_layers=1, hidden_features=64, out_features=32,
                 lr=1e-3, n_components=5, cnn_channels=64, dropout=0.1):
        super().__init__()
        self.save_hyperparameters()

        # Time-distributed feature extractor
        self.fc1 = nn.Linear(input_dim, hidden_features)
        self.ln1 = nn.LayerNorm(hidden_features) # ADDED: LayerNorm for time-step features
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.ln2 = nn.LayerNorm(out_features) # ADDED: LayerNorm

        # CNN feature extractors
        self.conv1 = nn.Conv1d(out_features, cnn_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(cnn_channels) # ADDED: BatchNorm for convolutional features
        self.conv3 = nn.Conv1d(out_features, cnn_channels, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(cnn_channels) # ADDED: BatchNorm

        # Learnable mixer for CNN outputs
        self.mixer = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=1, bias=True)

        # LSTM for temporal dependency
        fused_dim = cnn_channels # Input to LSTM is the mixed CNN output
        self.lstm = nn.LSTM(fused_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, dropout=dropout if num_layers > 1 else 0)

        # MDN Head
        self.mdn_head = nn.Linear(hidden_dim, 3 * n_components)
        self.n_components = n_components
        self.lr = lr

        # Apply weight initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, (nn.Conv1d, nn.Conv2d)):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, X, lengths=None):
        # Input shape X["main"]: (B, T, F_in)
        x = X["main"] # REMOVED redundant transposes

        # 1. Time-distributed feature extraction
        x = self.fc1(x)
        x = F.relu(self.ln1(x)) # CHANGED: Apply LayerNorm before ReLU
        x = self.fc2(x)
        x = F.relu(self.ln2(x)) # CHANGED: Apply LayerNorm before ReLU

        # 2. CNN feature extraction
        x = x.transpose(1, 2)   # Shape: (B, C_in, L=T)
        x1 = F.relu(self.bn1(self.conv1(x))) # CHANGED: Apply BatchNorm before ReLU
        x3 = F.relu(self.bn3(self.conv3(x))) # CHANGED: Apply BatchNorm before ReLU

        # 3. Mix CNN outputs
        paired = torch.stack([x1, x3], dim=1) # Shape: (B, 2, C_out, L)
        mixed = self.mixer(paired)            # Shape: (B, 1, C_out, L)
        
        # Prepare for LSTM
        xf = mixed.squeeze(1).transpose(1, 2) # Shape: (B, L, C_out)

        # 4. LSTM for sequence summary
        if lengths is not None:
            packed_input = pack_padded_sequence(
                xf, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            _, (h_last, _) = self.lstm(packed_input)
        else:
            _, (h_last, _) = self.lstm(xf)
        
        last_h = h_last[-1] # Shape: (B, H)
        
        # 5. MDN head for distribution parameters
        raw = self.mdn_head(last_h)
        pi, mu, sigma = mdn_split_params(raw, self.n_components)
        return {"pi": pi, "mu": mu, "sigma": sigma}

    def training_step(self, batch, batch_idx):
        X, y_line, lengths = batch
        mdn = self(X, lengths)
        loss = mdn_nll_multitarget(y_line, mdn["pi"], mdn["mu"], mdn["sigma"])
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y_line, lengths = batch
        mdn = self(X, lengths)
        loss = mdn_nll_multitarget(y_line, mdn["pi"], mdn["mu"], mdn["sigma"])
    # Log everything to progress bar
        self.log("val/loss", loss, prog_bar=True)
        self.log("val/pi_mean", mdn["pi"].mean(), prog_bar=True)
        self.log("val/pi_std", mdn["pi"].std(), prog_bar=True)
        self.log("val/mu_mean", mdn["mu"].mean(), prog_bar=True)
        self.log("val/mu_std", mdn["mu"].std(), prog_bar=True)
        self.log("val/sigma_mean", mdn["sigma"].mean(), prog_bar=True)
        self.log("val/sigma_std", mdn["sigma"].std(), prog_bar=True)
        
    def configure_optimizers(self): 
        return torch.optim.Adam(self.parameters(), lr=self.lr)
    # def configure_optimizers(self):
    #     optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
    #     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #         optimizer,
    #         mode="min",
    #         factor=0.2,   # Reduce LR by 80%
    #         patience=5,   # After 5 epochs of no val_loss improvement
    #         verbose=True
    #     )
    #     return {
    #         "optimizer": optimizer,
    #         "lr_scheduler": {
    #             "scheduler": scheduler,
    #             "monitor": "val/loss",  # Important!
    #         },
        # }


## CNNLSTM weightening

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import pytorch_lightning as pl
# Your mdn_split_params function remains the same
def mdn_split_params(raw_params, n_components):
    B, threeK = raw_params.shape
    assert threeK == 3 * n_components
    raw = raw_params.view(B, n_components, 3)
    pi = F.softmax(raw[..., 0], dim=-1)
    mu = raw[..., 1]
    sigma = F.softplus(raw[..., 2]) + 1e-4
    return pi, mu, sigma

def weighted_mdn_nll(y_true, mdn_params, weights):
    total_loss = 0.0
    num_lines = y_true.shape[1]
    B = y_true.shape[0]

    # Keep track if any valid lines are found
    valid_line_found = False

    for i in range(num_lines):
        y_target = y_true[:, i:i+1]  # (B,1)
        pi, mu, sigma = mdn_params['pi'][i], mdn_params['mu'][i], mdn_params['sigma'][i]

        mask = (y_target != 0).squeeze()
        if mask.sum() == 0:
            continue

        valid_line_found = True
        y_target_masked = y_target[mask]
        pi_masked, mu_masked, sigma_masked = pi[mask], mu[mask], sigma[mask]

        dist = Normal(loc=mu_masked, scale=sigma_masked)
        log_prob = dist.log_prob(y_target_masked.expand_as(mu_masked))
        log_mix_prob = torch.log(pi_masked + 1e-8) + log_prob
        log_likelihood = torch.logsumexp(log_mix_prob, dim=1)
        line_loss = -log_likelihood.mean()
        total_loss += weights[i] * line_loss

    if not valid_line_found:
        # Avoid returning a Python float; create a tensor with requires_grad
        total_loss = torch.tensor(0.0, device=y_true.device, requires_grad=True)

    return total_loss


class CNNLSTM_MDN_MultiHead(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim=128, num_layers=1, hidden_features=64, out_features=32,
                 lr=1e-3, n_components=5, cnn_channels=64, dropout=0.1, num_lines=9):
        super().__init__()
        self.save_hyperparameters()

        # --- Your CNN and LSTM base remains the same ---
        self.fc1 = nn.Linear(input_dim, hidden_features)
        self.ln1 = nn.LayerNorm(hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.ln2 = nn.LayerNorm(out_features)
        self.conv1 = nn.Conv1d(out_features, cnn_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(cnn_channels)
        self.conv3 = nn.Conv1d(out_features, cnn_channels, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(cnn_channels)
        self.mixer = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=1, bias=True)
        fused_dim = cnn_channels
        self.lstm = nn.LSTM(fused_dim, hidden_dim, num_layers=num_layers,
                              batch_first=True, dropout=dropout if num_layers > 1 else 0)

        # === MODIFICATION: Create a list of MDN heads ===
        self.num_lines = num_lines
        self.mdn_heads = nn.ModuleList(
            [nn.Linear(hidden_dim, 3 * n_components) for _ in range(num_lines)]
        )

        self.n_components = n_components
        self.lr = lr

        # === Define importance weights here ===
        # Using exponential decay: w_i = 0.9^(i-1)
        weights = torch.tensor([0.9**i for i in range(self.num_lines)])
        self.register_buffer('loss_weights', weights)

        self.apply(self._init_weights)

    def _init_weights(self, module): # Your init function is fine
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None: nn.init.constant_(module.bias, 0)
        elif isinstance(module, (nn.Conv1d, nn.Conv2d)):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None: nn.init.constant_(module.bias, 0)

    def forward(self, X, lengths=None):
        # --- Your forward pass for the base model is the same ---
        x = X["main"]
        x = F.relu(self.ln1(self.fc1(x)))
        x = F.relu(self.ln2(self.fc2(x)))
        x = x.transpose(1, 2)
        x1 = F.relu(self.bn1(self.conv1(x)))
        x3 = F.relu(self.bn3(self.conv3(x)))
        paired = torch.stack([x1, x3], dim=1)
        mixed = self.mixer(paired)
        xf = mixed.squeeze(1).transpose(1, 2)
        
        # We'll assume lengths is None for simplicity here, but your implementation is fine
        _, (h_last, _) = self.lstm(xf)
        last_h = h_last[-1]

        # === MODIFICATION: Get parameters from all heads ===
        all_params = {'pi': [], 'mu': [], 'sigma': []}
        for i in range(self.num_lines):
            raw_params = self.mdn_heads[i](last_h)
            pi, mu, sigma = mdn_split_params(raw_params, self.n_components)
            all_params['pi'].append(pi)
            all_params['mu'].append(mu)
            all_params['sigma'].append(sigma)

        return all_params
    

    # This would be inside your CNNLSTM_MDN_MultiHead class

    def training_step(self, batch, batch_idx):
        # Assuming your batch now provides a y tensor of shape (B, 9)
        # where y has your target line values, padded with -1.
        X, y, lengths = batch

        # Get the dictionary of parameter lists from the forward pass
        mdn_params = self(X, lengths)

        # Calculate loss using the new weighted function
        loss = weighted_mdn_nll(y, mdn_params, self.loss_weights)

        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

    # NOTE: You'll also need a validation_step that mirrors the training_step logic
    def validation_step(self, batch, batch_idx):
        X, y, lengths = batch
        mdn_params = self(X, lengths)
        loss = weighted_mdn_nll(y, mdn_params, self.loss_weights)
        self.log("val/loss", loss, prog_bar=True)
        return loss

## LSTM weightening with pi order

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import pytorch_lightning as pl


def mdn_split_params(raw_params, n_components):
    """
    Splits raw MDN output into mixture weights (pi), means (mu), and stds (sigma)
    """
    B, threeK = raw_params.shape
    assert threeK == 3 * n_components
    raw = raw_params.view(B, n_components, 3)
    pi = F.softmax(raw[..., 0], dim=-1)           # mixture probabilities
    mu = raw[..., 1]                              # means
    sigma = F.softplus(raw[..., 2]) + 1e-4       # stds
    return pi, mu, sigma


def weighted_mdn_nll(y_true, mdn_params, weights):
    """
    y_true: (B, num_lines)
    mdn_params: dict with 'pi', 'mu', 'sigma' each of shape (B, n_components)
    weights: (num_lines,) tensor
    """
    B, num_lines = y_true.shape
    pi, mu, sigma = mdn_params['pi'], mdn_params['mu'], mdn_params['sigma']  # (B, n_components)

    # Sort components by pi descending
    _, idx = torch.sort(pi, descending=True, dim=1)  # (B, n_components)

    total_loss = 0.0
    valid_line_found = False

    for i in range(num_lines):
        y_target = y_true[:, i]  # (B,)

        # Skip masked/padded targets
        mask = (y_target != 0)
        if mask.sum() == 0:
            continue
        valid_line_found = True

        # Select top pi component for this line
        top_mu = mu.gather(1, idx[:, i].unsqueeze(1)).squeeze(1)      # (B,)
        top_sigma = sigma.gather(1, idx[:, i].unsqueeze(1)).squeeze(1) # (B,)
        y_target_masked = y_target[mask]
        top_mu_masked = top_mu[mask]
        top_sigma_masked = top_sigma[mask]

        dist = Normal(top_mu_masked, top_sigma_masked)
        line_loss = -dist.log_prob(y_target_masked).mean()
        total_loss += weights[i] * line_loss

    if not valid_line_found:
        total_loss = torch.tensor(0.0, device=y_true.device, requires_grad=True)

    return total_loss


class CNNLSTM_MDN(pl.LightningModule):
    def __init__(self, input_dim, feature_eng=15,hidden_dim=32, n_components=9, num_lines=9, lr=1e-3, dropout=0.1):
        super().__init__()
        self.save_hyperparameters()
        self.num_lines = num_lines
        self.n_components = n_components
        self.lr = lr

        # Base network
        self.fc1 = nn.Linear(input_dim, feature_eng)
        self.ln1 = nn.LayerNorm(feature_eng)
        self.lstm = nn.LSTM(feature_eng, hidden_dim, batch_first=True, dropout=dropout)

        # Single MDN head predicting n_components Gaussians
        self.mdn_head = nn.Linear(hidden_dim, 3 * n_components)

        # Importance weights for lines
        weights = torch.tensor([0.9**i for i in range(num_lines)], dtype=torch.float)
        self.register_buffer("loss_weights", weights)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None: nn.init.constant_(module.bias, 0)

    def forward(self, X, lengths=None):
        """
        X: (B, T, input_dim)
        """
        x = X["main"]
        x = F.relu(self.ln1(self.fc1(x)))
        
        if lengths is not None:
            x = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
            _, (h_last, _) = self.lstm(x)
        else:
            _, (h_last, _) = self.lstm(x)

        last_h = h_last[-1]  # (B, hidden_dim)
        raw_params = self.mdn_head(last_h)  # (B, 3*n_components)
        pi, mu, sigma = mdn_split_params(raw_params, self.n_components)
        return {"pi": pi, "mu": mu, "sigma": sigma}

    def training_step(self, batch, batch_idx):
        X, y, lengths = batch
        mdn_params = self(X)
        loss = weighted_mdn_nll(y, mdn_params, self.loss_weights)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y, lengths = batch
        mdn_params = self(X)
        loss = weighted_mdn_nll(y, mdn_params, self.loss_weights)
        self.log("val/loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)


## CNNLSTM weightening order

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.distributions import Normal

def mdn_split_params(raw_params, n_components):
    """
    Splits raw MDN output into mixture weights (pi), means (mu), and stds (sigma)
    """
    B, threeK = raw_params.shape
    assert threeK == 3 * n_components
    raw = raw_params.view(B, n_components, 3)
    pi = F.softmax(raw[..., 0], dim=-1)           # mixture probabilities
    mu = raw[..., 1]                              # means
    sigma = F.softplus(raw[..., 2]) + 1e-4       # stds
    return pi, mu, sigma


def weighted_mdn_nll(y_true, mdn_params, weights):
    """
    y_true: (B, num_lines)
    mdn_params: dict with 'pi', 'mu', 'sigma' each of shape (B, n_components)
    weights: (num_lines,) tensor
    """
    B, num_lines = y_true.shape
    pi, mu, sigma = mdn_params['pi'], mdn_params['mu'], mdn_params['sigma']  # (B, n_components)

    # Sort components by pi descending
    _, idx = torch.sort(pi, descending=True, dim=1)  # (B, n_components)

    total_loss = 0.0
    valid_line_found = False

    for i in range(num_lines):
        y_target = y_true[:, i]  # (B,)

        # Skip masked/padded targets
        mask = (y_target != 0)
        if mask.sum() == 0:
            continue
        valid_line_found = True

        # Select top pi component for this line
        top_mu = mu.gather(1, idx[:, i].unsqueeze(1)).squeeze(1)      # (B,)
        top_sigma = sigma.gather(1, idx[:, i].unsqueeze(1)).squeeze(1) # (B,)
        y_target_masked = y_target[mask]
        top_mu_masked = top_mu[mask]
        top_sigma_masked = top_sigma[mask]

        dist = Normal(top_mu_masked, top_sigma_masked)
        line_loss = -dist.log_prob(y_target_masked).mean()
        total_loss += weights[i] * line_loss

    if not valid_line_found:
        total_loss = torch.tensor(0.0, device=y_true.device, requires_grad=True)

    return total_loss


class cnn_lstm(pl.LightningModule):
    def __init__(self, input_dim, feature_eng=15, hidden_dim=32, n_components=9, num_lines=9, lr=1e-3, dropout=0.1):
        super().__init__()
        self.save_hyperparameters()
        self.num_lines = num_lines
        self.n_components = n_components
        self.lr = lr

        # Base feature projection
        self.fc1 = nn.Linear(input_dim, feature_eng)
        self.ln1 = nn.LayerNorm(feature_eng)

        # Parallel conv1d branches
        self.k1 = nn.Conv1d(feature_eng, feature_eng, kernel_size=1, padding=0)
        self.k3 = nn.Conv1d(feature_eng, feature_eng, kernel_size=3, padding=1)

        # Fusion via conv2d
        # Input channels = 2 (from k1 + k3), Output = 1, kernel size (1,1) to fuse
        self.fusion_conv2d = nn.Conv2d(2, 1, kernel_size=(1, 1))

        # LSTM
        self.lstm = nn.LSTM(feature_eng, hidden_dim, batch_first=True, dropout=dropout)

        # Single MDN head predicting n_components Gaussians
        self.mdn_head = nn.Linear(hidden_dim, 3 * n_components)

        # Importance weights for lines
        weights = torch.tensor([0.9**i for i in range(num_lines)], dtype=torch.float)
        self.register_buffer("loss_weights", weights)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, X, lengths=None):
        """
        X: dict with key "main", value shape (B, T, input_dim)
        """
        x = X["main"]  # (B, T, input_dim)
        B, T, _ = x.shape

        # Fully connected projection
        x = F.relu(self.ln1(self.fc1(x)))  # (B, T, F)

        # Conv1d expects (B, F, T)
        x_cnn = x.transpose(1, 2)  # (B, F, T)

        # Parallel convs
        x1 = self.k1(x_cnn)  # (B, F, T)
        x3 = self.k3(x_cnn)  # (B, F, T)

        # Stack into 2-channel feature map
        stacked = torch.stack([x1, x3], dim=1)  # (B, 2, F, T)

        # Fuse with conv2d → (B, 1, F, T)
        fused = self.fusion_conv2d(stacked).squeeze(1)  # (B, F, T)

        # Back to (B, T, F)
        fused = fused.transpose(1, 2)

        # LSTM with packed sequence
        if lengths is not None:
            packed = nn.utils.rnn.pack_padded_sequence(fused, lengths.cpu(), batch_first=True, enforce_sorted=False)
            _, (h_last, _) = self.lstm(packed)
        else:
            _, (h_last, _) = self.lstm(fused)

        last_h = h_last[-1]  # (B, hidden_dim)
        raw_params = self.mdn_head(last_h)  # (B, 3 * n_components)

        # Assume you have mdn_split_params(pi, mu, sigma)
        pi, mu, sigma = mdn_split_params(raw_params, self.n_components)
        return {"pi": pi, "mu": mu, "sigma": sigma}
    
    def training_step(self, batch, batch_idx):
        X, y, lengths = batch
        mdn_params = self(X)
        loss = weighted_mdn_nll(y, mdn_params, self.loss_weights)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y, lengths = batch
        mdn_params = self(X)
        loss = weighted_mdn_nll(y, mdn_params, self.loss_weights)
        self.log("val/loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

## CNNLSTM weightening with sigma confidance

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import pytorch_lightning as pl
# Your mdn_split_params function remains the same
def mdn_split_params(raw_params, n_components):
    B, threeK = raw_params.shape
    assert threeK == 3 * n_components
    raw = raw_params.view(B, n_components, 3)
    pi = F.softmax(raw[..., 0], dim=-1)
    mu = raw[..., 1]
    sigma = F.softplus(raw[..., 2]) + 1e-4
    return pi, mu, sigma

def weighted_mdn_nll_with_sigma_penalty(y_true, mdn_params, weights, lambda_sigma=0.01):
    """
    Calculates weighted MDN NLL and adds a penalty for large sigmas.
    
    Args:
        lambda_sigma (float): The strength of the sigma penalty.
    """
    total_loss = 0.0
    num_lines = y_true.shape[1]

    for i in range(num_lines):
        y_target = y_true[:, i:i+1]
        pi, mu, sigma = mdn_params['pi'][i], mdn_params['mu'][i], mdn_params['sigma'][i]
        mask = (y_target != -1).squeeze()

        if mask.sum() == 0:
            continue

        y_target_masked = y_target[mask]
        pi_masked, mu_masked, sigma_masked = pi[mask], mu[mask], sigma[mask]
        
        # --- 1. NLL Loss Calculation (same as before) ---
        dist = Normal(loc=mu_masked, scale=sigma_masked)
        log_prob = dist.log_prob(y_target_masked.expand_as(mu_masked))
        log_mix_prob = torch.log(pi_masked + 1e-8) + log_prob
        log_likelihood = torch.logsumexp(log_mix_prob, dim=1)
        line_nll_loss = -log_likelihood.mean()

        # --- 2. NEW: Sigma Penalty Calculation ---
        # We penalize the mean of the sigmas for the most likely component
        # This focuses the penalty on the component the model actually uses
        most_likely_idx = torch.argmax(pi_masked, dim=1)
        most_likely_sigma = sigma_masked.gather(1, most_likely_idx.unsqueeze(1)).squeeze()
        sigma_penalty = torch.mean(most_likely_sigma)
        
        # --- 3. Combine and Weight ---
        combined_line_loss = line_nll_loss + (lambda_sigma * sigma_penalty)
        total_loss += weights[i] * combined_line_loss

    return total_loss

# In your training_step, you would call this new function:
# loss = weighted_mdn_nll_with_sigma_penalty(y, mdn_params, self.loss_weights, lambda_sigma=0.01)

class CNNLSTM_MDN_MultiHead(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim=128, num_layers=1, hidden_features=64, out_features=32,
                 lr=1e-3, n_components=5, cnn_channels=64, dropout=0.1, num_lines=9):
        super().__init__()
        self.save_hyperparameters()

        # --- Your CNN and LSTM base remains the same ---
        self.fc1 = nn.Linear(input_dim, hidden_features)
        self.ln1 = nn.LayerNorm(hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.ln2 = nn.LayerNorm(out_features)
        self.conv1 = nn.Conv1d(out_features, cnn_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(cnn_channels)
        self.conv3 = nn.Conv1d(out_features, cnn_channels, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(cnn_channels)
        self.mixer = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=1, bias=True)
        fused_dim = cnn_channels
        self.lstm = nn.LSTM(fused_dim, hidden_dim, num_layers=num_layers,
                              batch_first=True, dropout=dropout if num_layers > 1 else 0)

        # === MODIFICATION: Create a list of MDN heads ===
        self.num_lines = num_lines
        self.mdn_heads = nn.ModuleList(
            [nn.Linear(hidden_dim, 3 * n_components) for _ in range(num_lines)]
        )

        self.n_components = n_components
        self.lr = lr

        # === Define importance weights here ===
        # Using exponential decay: w_i = 0.9^(i-1)
        weights = torch.tensor([0.9**i for i in range(self.num_lines)])
        self.register_buffer('loss_weights', weights)

        self.apply(self._init_weights)

    def _init_weights(self, module): # Your init function is fine
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None: nn.init.constant_(module.bias, 0)
        elif isinstance(module, (nn.Conv1d, nn.Conv2d)):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None: nn.init.constant_(module.bias, 0)

    def forward(self, X, lengths=None):
        # --- Your forward pass for the base model is the same ---
        x = X["main"]
        x = F.relu(self.ln1(self.fc1(x)))
        x = F.relu(self.ln2(self.fc2(x)))
        x = x.transpose(1, 2)
        x1 = F.relu(self.bn1(self.conv1(x)))
        x3 = F.relu(self.bn3(self.conv3(x)))
        paired = torch.stack([x1, x3], dim=1)
        mixed = self.mixer(paired)
        xf = mixed.squeeze(1).transpose(1, 2)
        
        # We'll assume lengths is None for simplicity here, but your implementation is fine
        _, (h_last, _) = self.lstm(xf)
        last_h = h_last[-1]

        # === MODIFICATION: Get parameters from all heads ===
        all_params = {'pi': [], 'mu': [], 'sigma': []}
        for i in range(self.num_lines):
            raw_params = self.mdn_heads[i](last_h)
            pi, mu, sigma = mdn_split_params(raw_params, self.n_components)
            all_params['pi'].append(pi)
            all_params['mu'].append(mu)
            all_params['sigma'].append(sigma)

        return all_params
    

    # This would be inside your CNNLSTM_MDN_MultiHead class

    def training_step(self, batch, batch_idx):
        # Assuming your batch now provides a y tensor of shape (B, 9)
        # where y has your target line values, padded with -1.
        X, y, lengths = batch

        # Get the dictionary of parameter lists from the forward pass
        mdn_params = self(X, lengths)

        # Calculate loss using the new weighted function
        loss = weighted_mdn_nll_with_sigma_penalty(y, mdn_params, self.loss_weights)

        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

    # NOTE: You'll also need a validation_step that mirrors the training_step logic
    def validation_step(self, batch, batch_idx):
        X, y, lengths = batch
        mdn_params = self(X, lengths)
        loss = weighted_mdn_nll_with_sigma_penalty(y, mdn_params, self.loss_weights)
        self.log("val/loss", loss, prog_bar=True)
        return loss

## CNNlSTM

In [2]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence

def mdn_split_params(raw_params, n_components):
    """
    raw_params: (B, 3K) tensor from mdn_head
    returns:
        pi    (B, K) mixture weights
        mu    (B, K) means
        sigma (B, K) std devs
    """
    B, threeK = raw_params.shape
    assert threeK == 3 * n_components

    raw = raw_params.view(B, n_components, 3)

    pi = raw[..., 0]                 # (B,K)
    mu = raw[..., 1]                 # (B,K)
    sigma = raw[..., 2]              # (B,K)

    pi = F.softmax(pi, dim=-1)       # weights sum to 1
    sigma = F.softplus(sigma) + 1e-4 # strictly positive
    return pi, mu, sigma


def mdn_nll_multitarget(y_line, pi, mu, sigma):
    """
    Negative log-likelihood for MDN with multiple valid targets per sample.
    Args:
        y_line : (B, L) padded targets (0 where invalid)
        pi, mu, sigma : (B, K) MDN params
    Returns:
        scalar loss
    """
    B, K = mu.shape
    losses = []

    for b in range(B):
        valid_y = y_line[b][y_line[b] > 0]  # (M,)
        if len(valid_y) == 0:
            continue

        # expand to (M, K)
        y_exp = valid_y.unsqueeze(-1).expand(-1, K)

        log_prob = -0.5 * ((y_exp - mu[b]) / (sigma[b] + 1e-8))**2 \
                   - torch.log(sigma[b] + 1e-8) \
                   - 0.5 * torch.log(torch.tensor(2.0 * torch.pi, device=y_line.device))

        log_mix = torch.log(pi[b] + 1e-8) + log_prob
        log_sum = torch.logsumexp(log_mix, dim=-1)  # (M,)

        losses.append(-log_sum.mean())

    if len(losses) == 0:
        return torch.tensor(0.0, device=y_line.device, requires_grad=True)

    return torch.stack(losses).mean()


class CNNLSTM_MDN(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim=128, num_layers=1, hidden_features=64, out_features=32,
                 lr=1e-3, n_components=5, cnn_channels=64, dropout=0.1):
        super().__init__()
        self.save_hyperparameters()

        # Time-distributed feature extractor
        self.fc1 = nn.Linear(input_dim, hidden_features)
        self.ln1 = nn.LayerNorm(hidden_features) # ADDED: LayerNorm for time-step features
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.ln2 = nn.LayerNorm(out_features) # ADDED: LayerNorm

        # CNN feature extractors
        self.conv1 = nn.Conv1d(out_features, cnn_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(cnn_channels) # ADDED: BatchNorm for convolutional features
        self.conv3 = nn.Conv1d(out_features, cnn_channels, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(cnn_channels) # ADDED: BatchNorm

        # Learnable mixer for CNN outputs
        self.mixer = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=1, bias=True)

        # LSTM for temporal dependency
        fused_dim = cnn_channels # Input to LSTM is the mixed CNN output
        self.lstm = nn.LSTM(fused_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, dropout=dropout if num_layers > 1 else 0)

        # MDN Head
        self.mdn_head = nn.Linear(hidden_dim, 3 * n_components)
        self.n_components = n_components
        self.lr = lr

        # Apply weight initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, (nn.Conv1d, nn.Conv2d)):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, X, lengths=None):
        # Input shape X["main"]: (B, T, F_in)
        x = X["main"] # REMOVED redundant transposes

        # 1. Time-distributed feature extraction
        x = self.fc1(x)
        x = F.relu(self.ln1(x)) # CHANGED: Apply LayerNorm before ReLU
        x = self.fc2(x)
        x = F.relu(self.ln2(x)) # CHANGED: Apply LayerNorm before ReLU
        # 2. CNN feature extraction
        x = x.transpose(1, 2)   # Shape: (B, C_in, L=T)
        x1 = F.relu(self.bn1(self.conv1(x))) # CHANGED: Apply BatchNorm before ReLU
        x3 = F.relu(self.bn3(self.conv3(x))) # CHANGED: Apply BatchNorm before ReLU

        # 3. Mix CNN outputs
        paired = torch.stack([x1, x3], dim=1) # Shape: (B, 2, C_out, L)
        mixed = self.mixer(paired)            # Shape: (B, 1, C_out, L)
        
        # Prepare for LSTM
        xf = mixed.squeeze(1).transpose(1, 2) # Shape: (B, L, C_out)

        # 4. LSTM for sequence summary
        if lengths is not None:
            packed_input = pack_padded_sequence(
                xf, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            _, (h_last, _) = self.lstm(packed_input)
        else:
            _, (h_last, _) = self.lstm(xf)
        
        last_h = h_last[-1] # Shape: (B, H)
        
        # 5. MDN head for distribution parameters
        raw = self.mdn_head(last_h)
        pi, mu, sigma = mdn_split_params(raw, self.n_components)
        return {"pi": pi, "mu": mu, "sigma": sigma}

    def training_step(self, batch, batch_idx):
        X, y_line, lengths = batch
        mdn = self(X, lengths)
        loss = mdn_nll_multitarget(y_line, mdn["pi"], mdn["mu"], mdn["sigma"])
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y_line, lengths = batch
        mdn = self(X, lengths)
        loss = mdn_nll_multitarget(y_line, mdn["pi"], mdn["mu"], mdn["sigma"])
    # Log everything to progress bar
        self.log("val/loss", loss, prog_bar=True)
        self.log("val/pi_mean", mdn["pi"].mean(), prog_bar=True)
        self.log("val/pi_std", mdn["pi"].std(), prog_bar=True)
        self.log("val/mu_mean", mdn["mu"].mean(), prog_bar=True)
        self.log("val/mu_std", mdn["mu"].std(), prog_bar=True)
        self.log("val/sigma_mean", mdn["sigma"].mean(), prog_bar=True)
        self.log("val/sigma_std", mdn["sigma"].std(), prog_bar=True)
        
    # Inside your CNNLSTM_MDN class
    # def configure_optimizers(self):
    #     optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        
    #     # Define the scheduler
    #     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #         optimizer,
    #         mode='min',      # We want to minimize the validation loss
    #         factor=0.5,      # Reduce LR by 50% (1.0 -> 0.2)
    #         patience=10,      # Wait 5 validation epochs with no improvement before reducing
    #         verbose=True
    #     )
        
    #     return {
    #         "optimizer": optimizer,
    #         "lr_scheduler": {
    #             "scheduler": scheduler,
    #             "monitor": "val/loss",  # The metric to watch
    #         },
    #     }
    
    def configure_optimizers(self): 
        return torch.optim.Adam(self.parameters(), lr=self.lr)
    # def configure_optimizers(self):
    #     optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
    #     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #         optimizer,
    #         mode="min",
    #         factor=0.2,   # Reduce LR by 80%
    #         patience=5,   # After 5 epochs of no val_loss improvement
    #         verbose=True
    #     )
    #     return {
    #         "optimizer": optimizer,
    #         "lr_scheduler": {
    #             "scheduler": scheduler,
    #             "monitor": "val/loss",  # Important!
    #         },
        # }


## CNNLSTM scalie

In [2]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence

def mdn_split_params(raw_params, n_components, mu_scale=10, mu_bias=.9, sigma_scale=10.0):
    """
    Split raw MDN parameters into (pi, mu, sigma).

    Args:
        raw_params: (B, 3 * K) from the network
        n_components: number of mixture components
        mu_scale: scaling factor for mu (default 1.0 = no scaling)
        mu_bias: shift/bias applied after scaling
        sigma_scale: scaling factor for sigma (default 10.0)
    """
    B = raw_params.size(0)
    raw = raw_params.view(B, n_components, 3)

    pi_raw = raw[..., 0]
    mu_raw = raw[..., 1]
    sigma_raw = raw[..., 2]

    pi = F.softmax(pi_raw, dim=-1)
    mu = mu_raw / mu_scale + mu_bias
    sigma = F.softplus(sigma_raw / sigma_scale) + 1e-4

    return pi, mu, sigma


def mdn_nll_multitarget(y_line, pi, mu, sigma):
    """
    Negative log-likelihood for MDN with multiple valid targets per sample.
    Args:
        y_line : (B, L) padded targets (0 where invalid)
        pi, mu, sigma : (B, K) MDN params
    Returns:
        scalar loss
    """
    B, K = mu.shape
    losses = []

    for b in range(B):
        valid_y = y_line[b][y_line[b] > 0]  # (M,)
        if len(valid_y) == 0:
            continue

        # expand to (M, K)
        y_exp = valid_y.unsqueeze(-1).expand(-1, K)

        log_prob = -0.5 * ((y_exp - mu[b]) / (sigma[b] + 1e-8))**2 \
                   - torch.log(sigma[b] + 1e-8) \
                   - 0.5 * torch.log(torch.tensor(2.0 * torch.pi, device=y_line.device))

        log_mix = torch.log(pi[b] + 1e-8) + log_prob
        log_sum = torch.logsumexp(log_mix, dim=-1)  # (M,)

        losses.append(-log_sum.mean())

    if len(losses) == 0:
        return torch.tensor(0.0, device=y_line.device, requires_grad=True)

    return torch.stack(losses).mean()


class CNNLSTM_MDN(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim=128, num_layers=1, hidden_features=64, out_features=32,
                 lr=1e-3, n_components=5, cnn_channels=64, dropout=0.1):
        super().__init__()
        self.save_hyperparameters()

        # Time-distributed feature extractor
        self.fc1 = nn.Linear(input_dim, hidden_features)
        self.ln1 = nn.LayerNorm(hidden_features) # ADDED: LayerNorm for time-step features
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.ln2 = nn.LayerNorm(out_features) # ADDED: LayerNorm

        # CNN feature extractors
        self.conv1 = nn.Conv1d(out_features, cnn_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(cnn_channels) # ADDED: BatchNorm for convolutional features
        self.conv3 = nn.Conv1d(out_features, cnn_channels, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(cnn_channels) # ADDED: BatchNorm

        # Learnable mixer for CNN outputs
        self.mixer = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=1, bias=True)

        # LSTM for temporal dependency
        fused_dim = cnn_channels # Input to LSTM is the mixed CNN output
        self.lstm = nn.LSTM(fused_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, dropout=dropout if num_layers > 1 else 0)

        # MDN Head
        self.mdn_head = nn.Linear(hidden_dim, 3 * n_components)
        self.n_components = n_components
        self.lr = lr

        # Apply weight initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, (nn.Conv1d, nn.Conv2d)):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, X, lengths=None):
        # Input shape X["main"]: (B, T, F_in)
        x = X["main"] 

        # --- Debug print first candle ---
        # if x.ndim == 3:  # batched: (B, T, F)
        #     first_candle = x[0, 0, :]   # first sample, first time step, all features
        #     print("First candle features:", first_candle.detach().cpu().numpy())
        # elif x.ndim == 2:  # single sequence: (T, F)
        #     first_candle = x[0, :]      # first time step, all features
        #     print("First candle features:", first_candle.detach().cpu().numpy())
        # else:
        #     print("Unexpected shape for x:", x.shape)
        # 1. Time-distributed feature extraction
        x = self.fc1(x)
        x = F.relu(self.ln1(x)) # CHANGED: Apply LayerNorm before ReLU
        x = self.fc2(x)
        x = F.relu(self.ln2(x)) # CHANGED: Apply LayerNorm before ReLU

        # 2. CNN feature extraction
        x = x.transpose(1, 2)   # Shape: (B, C_in, L=T)
        x1 = F.relu(self.bn1(self.conv1(x))) # CHANGED: Apply BatchNorm before ReLU
        x3 = F.relu(self.bn3(self.conv3(x))) # CHANGED: Apply BatchNorm before ReLU

        # 3. Mix CNN outputs
        paired = torch.stack([x1, x3], dim=1) # Shape: (B, 2, C_out, L)
        mixed = self.mixer(paired)            # Shape: (B, 1, C_out, L)
        
        # Prepare for LSTM
        xf = mixed.squeeze(1).transpose(1, 2) # Shape: (B, L, C_out)

        # 4. LSTM for sequence summary
        if lengths is not None:
            packed_input = pack_padded_sequence(
                xf, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            _, (h_last, _) = self.lstm(packed_input)
        else:
            _, (h_last, _) = self.lstm(xf)
        
        last_h = h_last[-1] # Shape: (B, H)
        
        # 5. MDN head for distribution parameters
        raw = self.mdn_head(last_h)
        pi, mu, sigma = mdn_split_params(raw, self.n_components)
        return {"pi": pi, "mu": mu, "sigma": sigma}

    def training_step(self, batch, batch_idx):
        X, y_line, lengths = batch
        mdn = self(X, lengths)
        loss = mdn_nll_multitarget(y_line, mdn["pi"], mdn["mu"], mdn["sigma"])
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y_line, lengths = batch
        mdn = self(X, lengths)
        loss = mdn_nll_multitarget(y_line, mdn["pi"], mdn["mu"], mdn["sigma"])
    # Log everything to progress bar
        self.log("val/loss", loss, prog_bar=True)
        self.log("val/pi_mean", mdn["pi"].mean(), prog_bar=True)
        self.log("val/pi_std", mdn["pi"].std(), prog_bar=True)
        self.log("val/mu_mean", mdn["mu"].mean(), prog_bar=True)
        self.log("val/mu_std", mdn["mu"].std(), prog_bar=True)
        self.log("val/sigma_mean", mdn["sigma"].mean(), prog_bar=True)
        self.log("val/sigma_std", mdn["sigma"].std(), prog_bar=True)
        
    # # Inside your CNNLSTM_MDN class
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-4)
    # def configure_optimizers(self):
    #     optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        
    #     # Define the scheduler
    #     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #         optimizer,
    #         mode='min',      # We want to minimize the validation loss
    #         factor=0.5,      # Reduce LR by 80% (1.0 -> 0.2)
    #         patience=10,      # Wait 5 validation epochs with no improvement before reducing
    #         verbose=True
    #     )
        
    #     return {
    #         "optimizer": optimizer,
    #         "lr_scheduler": {
    #             "scheduler": scheduler,
    #             "monitor": "val/loss",  # The metric to watch
    #         },
    #     }
    # def configure_optimizers(self):
    #     return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


## CNNtransformer wheightening order

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import pytorch_lightning as pl
import math

# --- Helper Functions and Modules ---

def mdn_split_params(raw_params, n_components):
    """
    Splits raw MDN output into mixture weights (pi), means (mu), and stds (sigma).
    This function is used by each individual MDN head.
    """
    B, threeK = raw_params.shape
    assert threeK == 3 * n_components
    raw = raw_params.view(B, n_components, 3)
    pi = F.softmax(raw[..., 0], dim=-1)
    mu = raw[..., 1]
    sigma = F.softplus(raw[..., 2]) + 1e-6 # Added a small epsilon for stability
    return pi, mu, sigma

class PositionalEncoding(nn.Module):
    """
    Injects positional information into the input sequence for the Transformer.
    """
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 500):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

# --- Weighted Loss Function for Multi-Head Architecture ---

def weighted_mdn_nll_multihead(y_true, mdn_params_list, weights, padding_value=-1):
    """
    Calculates the weighted negative log-likelihood for a multi-headed MDN.
    This version correctly handles multiple heads and calculates the full NLL for each.

    Args:
        y_true (Tensor): Padded target values, shape (B, num_lines).
        mdn_params_list (list): A list of dicts, one for each head.
        weights (Tensor): A 1D tensor of importance weights, shape (num_lines,).
        padding_value (int): Value used for padding in y_true.
    """
    total_loss = 0.0
    num_lines = y_true.shape[1]
    
    for i in range(num_lines):
        y_target = y_true[:, i:i+1]
        pi, mu, sigma = mdn_params_list[i]['pi'], mdn_params_list[i]['mu'], mdn_params_list[i]['sigma']

        # Create a mask for valid (non-padded) targets for this line
        mask = (y_target != padding_value).squeeze()

        if mask.sum() == 0:  # Skip if no valid targets for this line in the batch
            continue

        # Select only the valid data for this line's loss calculation
        y_target_masked = y_target[mask]
        pi_masked, mu_masked, sigma_masked = pi[mask], mu[mask], sigma[mask]
        
        # Use torch.distributions for a clean and stable calculation
        dist = Normal(loc=mu_masked, scale=sigma_masked)
        
        # Calculate log probabilities of the target values in each Gaussian component
        log_prob = dist.log_prob(y_target_masked.expand_as(mu_masked))
        
        # Mix the probabilities using the mixture weights (pi)
        log_mix_prob = torch.log(pi_masked + 1e-8) + log_prob
        
        # Use logsumexp for numerical stability to get the log-likelihood
        log_likelihood = torch.logsumexp(log_mix_prob, dim=1)
        
        # Calculate the mean negative log-likelihood for this line
        line_loss = -log_likelihood.mean()

        # Apply the importance weight and add to total loss
        total_loss += weights[i] * line_loss

    # If no valid lines were found in the entire batch, return a zero tensor
    if not isinstance(total_loss, torch.Tensor):
        return torch.tensor(0.0, device=y_true.device, requires_grad=True)
        
    return total_loss

# --- The CNN-Transformer Model ---

class cnn_transformer(pl.LightningModule):
    def __init__(self, input_dim, cnn_out_channels=64, d_model=128, nhead=4, num_encoder_layers=2,
                 n_components=9, num_lines=9, lr=1e-4, dropout=0.1):
        super().__init__()
        self.save_hyperparameters()
        self.num_lines = num_lines
        self.n_components = n_components
        self.lr = lr
        
        # 1. CNN Feature Extractor Block
        self.cnn_extractor = nn.Sequential(
            nn.Conv1d(input_dim, cnn_out_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(cnn_out_channels),
            nn.ReLU(),
            nn.Conv1d(cnn_out_channels, d_model, kernel_size=3, padding=1),
            nn.BatchNorm1d(d_model),
            nn.ReLU()
        )
        
        # 2. Positional Encoding
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        
        # 3. Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        
        # 4. Multi-Head MDN Output
        self.mdn_heads = nn.ModuleList([
            nn.Linear(d_model, 3 * n_components) for _ in range(num_lines)
        ])
        
        # Importance weights for lines (exponential decay)
        weights = torch.tensor([0.9**i for i in range(num_lines)], dtype=torch.float)
        self.register_buffer("loss_weights", weights)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None: nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, X, src_key_padding_mask=None):
        """
        X: (B, T, input_dim)
        src_key_padding_mask: (B, T) boolean mask for padded elements in X
        """
        x = X["main"]
        
        # 1. CNN Feature Extraction
        # Input for Conv1d needs to be (B, C_in, L), so we permute
        x = x.permute(0, 2, 1)
        x = self.cnn_extractor(x)
        # Permute back to (B, T, C_out) for Transformer
        x = x.permute(0, 2, 1)

        # 2. Add Positional Encoding
        # Transformer expects (T, B, C), so permute again
        x = x.permute(1, 0, 2)
        x = self.pos_encoder(x)
        # Permute back to (B, T, C) for batch_first=True
        x = x.permute(1, 0, 2)

        # 3. Transformer Encoder
        # The mask should indicate which key values are NOT to be attended to
        encoded_seq = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)

        # We use the representation of the last valid timestep for prediction
        # (A common strategy, alternatively you could use mean pooling)
        # For simplicity, we'll take the last hidden state of the sequence.
        sequence_summary = encoded_seq[:, -1, :] # (B, d_model)
        
        # 4. Get parameters from all MDN heads
        mdn_params_list = []
        for i in range(self.num_lines):
            raw_params = self.mdn_heads[i](sequence_summary)
            pi, mu, sigma = mdn_split_params(raw_params, self.n_components)
            mdn_params_list.append({"pi": pi, "mu": mu, "sigma": sigma})

        return mdn_params_list

    def training_step(self, batch, batch_idx):
        X, y, lengths = batch
        # Create the padding mask for the transformer
        # True values indicate positions that should be ignored.
        max_len = X['main'].shape[1]
        mask = torch.arange(max_len, device=self.device)[None, :] >= lengths[:, None]

        mdn_params = self(X, src_key_padding_mask=mask)
        # Use a padding value of -1 for the loss function
        loss = weighted_mdn_nll_multihead(y, mdn_params, self.loss_weights, padding_value=-1)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y, lengths = batch
        max_len = X['main'].shape[1]
        mask = torch.arange(max_len, device=self.device)[None, :] >= lengths[:, None]
        
        mdn_params = self(X, src_key_padding_mask=mask)
        loss = weighted_mdn_nll_multihead(y, mdn_params, self.loss_weights, padding_value=-1)
        self.log("val/loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)


# data manipulation

In [32]:
import pandas as pd
df_labels = pd.read_csv("/home/iatell/projects/meta-learning/data/line_seq_ordered.csv")
cols = [f'price_line{i}' for i in range(1, 10)]
df_labels = df_labels.dropna(subset=cols, how='all')
df_labels = df_labels.rename(columns={c: c.replace('price_line', 'linePrice_') 
                        for c in df_labels.columns if c.startswith('price_line')})
df_labels.to_csv("/home/iatell/projects/meta-learning/data/line_seq_ordered.csv", index=False)      
#     # overwrites the old file
df_labels

Unnamed: 0,startTime,endTime,startIndex,endIndex,linePrice_1,linePrice_2,linePrice_3,linePrice_4,linePrice_5,linePrice_6,linePrice_7,linePrice_8,linePrice_9
0,1514764800,1515110400,0,4,,0.878016,0.788209,,,,,,
1,1514764800,1515283200,0,6,,1.055290,0.923251,0.828937,,,,,
2,1515024000,1515369600,3,7,1.143628,,,,,,,,
3,1515456000,1514937600,2,8,1.139775,,,,,,,,
4,1515110400,1515542400,4,9,1.143279,0.964469,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...
328,1651795200,1649116800,1555,1586,0.873150,0.825739,0.905267,0.938913,,,,0.955736,
330,1652054400,1652227200,1589,1591,1.063729,,,1.023085,,,,,
331,1652572800,1651881600,1587,1595,0.813907,0.870793,,,,,,0.788406,0.904141
332,1653264000,1652227200,1591,1603,1.042211,1.075683,0.992004,0.958532,,,,,


# train

## simple

In [None]:
import sys
from pathlib import Path

# Current notebook location
notebook_path = Path().resolve()

# Add parent folder (meta/) to sys.path
sys.path.append(str(notebook_path.parent))
import joblib
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from datetime import datetime
from preprocess.multi_regression_seq_dif import preprocess_sequences_csv_multilines
# from models.LSTM.lstm_multi_line_reg_seq_dif import LSTMMultiRegressor
from utils.make_step import make_step
from utils.padding_batch_reg import collate_batch
from utils.get_init_argumens import get_init_args
import pandas as pd
import io
import numpy as np
import os
from add_ons.drop_column import drop_columns
from add_ons.normalize_candle_seq import add_label_normalized_candles
from add_ons.feature_pipeline3 import FeaturePipeline
from add_ons.candle_dif_rate_of_change_percentage import add_candle_rocp
from add_ons.candle_rate_of_change import add_candle_ratios
from sklearn.metrics import accuracy_score, f1_score
# ---------------- Evaluation ---------------- #
@torch.no_grad()
def evaluate_model_mdn(model, val_loader, zero_idx=0, threshold=0.1):
    """
    Evaluate CNN–LSTM–MDN model (last-output version).

    Args
    ----
    model : pl.LightningModule with MDN forward
    val_loader : DataLoader yielding (X, y, lengths)
    zero_idx : which mixture component is considered "no-line" (usually 0)
    threshold : if pi[:,zero_idx] > threshold → predict invalid

    Returns
    -------
    dict with mse, mae, acc, f1
    """
    model.eval()
    all_preds_reg, all_labels_reg = [], []
    all_preds_len, all_labels_len = [], []

    device = next(model.parameters()).device

    with torch.no_grad():
        for X_batch, y_batch, lengths in val_loader:
            if isinstance(X_batch, dict):
                X_batch = {k: v.to(device) for k, v in X_batch.items()}
            else:
                X_batch = X_batch.to(device)

            y_batch = y_batch.to(device)
            mdn = model(X_batch, lengths)
            pi, mu, sigma = mdn["pi"], mdn["mu"], mdn["sigma"]  # (B,K)

            # regression expectation
            y_pred = (pi * mu).sum(dim=-1)  # (B,)
            B = y_batch.size(0)
            y_len = (y_batch > 0).sum(dim=1)                # (B,)
            idx = torch.clamp(y_len - 1, min=0)             # last valid index
            y_true = y_batch[torch.arange(B, device=y_batch.device), idx]  # (B,)
            # only last step
            # print("lengths(features):", lengths[:10])
            # print("lengths(labels):", y_len[:10])

            all_preds_reg.append(y_pred.cpu().numpy())
            all_labels_reg.append(y_true.cpu().numpy())

            # validity classification
            pi_zero = pi[:, zero_idx]  # (B,)
            pred_valid = (pi_zero < (1 - threshold)).long()
            true_valid = torch.ones_like(pred_valid)  # last step always valid

            all_preds_len.extend(pred_valid.cpu().numpy().tolist())
            all_labels_len.extend(true_valid.cpu().numpy().tolist())


        # ----- Regression metrics -----
    all_preds_reg = np.concatenate(all_preds_reg)  # (N,)
    all_labels_reg = np.concatenate(all_labels_reg)
    mse = ((all_preds_reg - all_labels_reg) ** 2).mean()
    mae = np.abs(all_preds_reg - all_labels_reg).mean()
    # ----- Validity metrics -----
    acc = accuracy_score(all_labels_len, all_preds_len)
    f1 = f1_score(all_labels_len, all_preds_len, average="macro")

    print("\n📊 Validation Metrics (MDN, last-output):")
    print(f"  Regression → MSE: {mse:.6f}, MAE: {mae:.6f}")
    print(f"  Validity   → Acc: {acc:.4f}, F1: {f1:.4f}")

    return {"mse": mse, "mae": mae, "acc": acc, "f1": f1}
# ---------------- Train ---------------- #
def train_model(
    data_csv,
    labels_csv,
    model_out_dir="models/saved_models",
    do_validation=True,
    hidden_dim=200,
    num_layers=1,
    lr=0.001,
    batch_size=32,
    max_epochs=1000,
    save_model=False,
    return_val_accuracy = True,
    test_mode = False,
    early_stop = False
):

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_out = f"{model_out_dir}/lstm_model_multireg_{timestamp}.pt"
    meta_out  = f"{model_out_dir}/lstm_meta_multireg_{timestamp}.pkl"

    pipeline = FeaturePipeline(
        steps=[
            # make_step(add_label_normalized_candles),
            make_step(add_candle_rocp),
            make_step(drop_columns, cols_to_drop=["open","high","low","close","volume"]),
            
        ],
        # norm_methods={
        #     "main": {
        #         "upper_shadow": "robust", "body": "standard", "lower_shadow": "standard",
        #         "upper_body_ratio": "standard", "lower_body_ratio": "standard",
        #         "upper_lower_body_ratio": "standard", "Candle_Color": "standard"
        #     }
        # },
        per_window_flags=[
            False, 
          False, 
        #   True
                ]
    )
    # Preprocess: pad linePrices and sequences
    if do_validation:
        train_ds, val_ds, df, feature_cols, max_len_y = preprocess_sequences_csv_multilines(
            data_csv, labels_csv,
            val_split=True,
            for_xgboost=False,
            debug_sample=True,
            feature_pipeline=pipeline
        )
    else:
        train_ds, df, feature_cols, max_len_y = preprocess_sequences_csv_multilines(
            data_csv, labels_csv,
            val_split=False,
            for_xgboost=False,
            debug_sample=False,
            feature_pipeline=pipeline
        )
        val_ds = None

    sample = train_ds[0][0]  # first sample's features
    if isinstance(sample, dict):  # multiple feature groups
        input_dim = sample['main'].shape[1]
    else:  # single tensor
        input_dim = sample.shape[1]

    model = CNNLSTM_MDN(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        lr=lr
    )
    init_args = get_init_args(model, input_dim=input_dim, hidden_dim=hidden_dim, num_layers=num_layers, lr=lr)

    model_class_info = {
        "module": model.__class__.__module__,
        "class": model.__class__.__name__,
        "init_args": init_args
    }

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
    val_loader = DataLoader(val_ds, batch_size=batch_size, collate_fn=collate_batch) if val_ds else None
    
    # --- Debug / Test mode --- #
    if test_mode:
        save_model = False
        from itertools import islice

        # Try to grab 3rd batch; if not available, take first
        try:
            batch = next(islice(iter(train_loader), 2, 3))
        except StopIteration:
            batch = next(iter(train_loader))

        X_batch_dict, y_batch, lengths = batch

        print("🔍 Debug batch:")
        if isinstance(X_batch_dict, dict):
            print("  Keys in X_batch:", list(X_batch_dict.keys()))
        print("  y_batch shape:", y_batch.shape)
        print("  First label in batch:", y_batch[0])

        # --- Track real column names for each feature group ---
        feature_names_dict = {}
        for name, X_batch in X_batch_dict.items():
            if name == "main":
                # Use actual feature columns after preprocessing
                feature_names_dict[name] = feature_cols
            else:
                # For extra feature groups, fallback to generic names
                feature_names_dict[name] = [f"{name}_{i}" for i in range(X_batch.shape[2])]

        dfs = []
        for name, X_batch in X_batch_dict.items():
            print(f"\nFeature group: {name}")
            print("  X_batch shape:", X_batch.shape)
            print("  First sequence in batch (first  steps):\n", X_batch[0][:])

            batch_size_, seq_len, feature_dim = X_batch.shape
            df_part = pd.DataFrame(
                X_batch.reshape(batch_size_ * seq_len, feature_dim).numpy(),
                columns=feature_names_dict[name]
            )
            dfs.append(df_part)

        # Combine all feature groups horizontally
        global df_seq
        df_seq = pd.concat(dfs, axis=1)
        print("\n✅ Combined df_seq shape:", df_seq.shape)
        print("✅ Column names in df_seq:", df_seq.columns.tolist())

    # --- Early stopping --- #
    if early_stop == True:
        from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
        early_stop_callback = EarlyStopping(
            monitor="val_loss",   # metric to monitor (must be logged in your LightningModule)
            patience=10,          # number of epochs with no improvement before stopping
            min_delta=0.001,      # minimum improvement to qualify as "better"
            mode="min",           # "min" for loss, "max" for accuracy
            verbose=True
        )

        checkpoint_callback = ModelCheckpoint(
            dirpath=model_out_dir,
            filename="best_model",
            save_top_k=1,
            monitor="val_loss",
            mode="min"
        )
        callbacks=[early_stop_callback,checkpoint_callback]

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices=1,
        fast_dev_run=test_mode,
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
        callbacks= callbacks if early_stop else None
    )

    trainer.fit(model, train_loader, val_loader)

    if save_model:
        os.makedirs(model_out_dir, exist_ok=True)
        trainer.save_checkpoint(model_out)
        joblib.dump({
    "input_dim": input_dim,
    "hidden_dim": hidden_dim,
    "num_layers": num_layers,
    "max_len_y": max_len_y,
    "feature_cols": feature_cols,
    "scalers": pipeline.scalers,
    "pipeline_config": pipeline.export_config(),
    "model_class_info": model_class_info   # ✅ save model class info
}, meta_out)
        
    # --- Evaluation --- #
    if do_validation:
        mse, mae, acc, f1 = evaluate_model_mdn(model, val_loader)
        if return_val_accuracy:
            return {"mse": mse, "mae": mae, "acc": acc, "f1": f1}
        
if __name__ == "__main__":
    train_model(
        "/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles.csv",
        "/home/iatell/projects/meta-learning/data/seq_line_labels.csv",
        save_model=True,
        do_validation=True,
        test_mode = False
    )


## ordered

### cnn lstm

In [3]:
import sys
from pathlib import Path

# Current notebook location
notebook_path = Path().resolve()

# Add parent folder (meta/) to sys.path
sys.path.append(str(notebook_path.parent))
import joblib
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from datetime import datetime
from preprocess.multi_regression_seq_dif3 import preprocess_sequences_csv_multilines
# from models.LSTM.lstm_multi_line_reg_seq_dif import LSTMMultiRegressor
from utils.make_step import make_step
from utils.padding_batch_reg import collate_batch
from utils.get_init_argumens import get_init_args
import pandas as pd
import io
import numpy as np
import os
from add_ons.drop_columns2 import drop_columns
from add_ons.normalize_candle_seq import add_label_normalized_candles
from add_ons.feature_pipeline5 import FeaturePipeline
from add_ons.candle_dif_rate_of_change_percentage2 import add_candle_rocp
from add_ons.candle_rate_of_change import add_candle_ratios
from add_ons.candle_proportion_simple import add_candle_shape_features
from sklearn.metrics import accuracy_score, f1_score,mean_squared_error,mean_absolute_error
from utils.to_address import to_address
# ---------------- Evaluation ---------------- #
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, accuracy_score, f1_score
import torch

@torch.no_grad()
def evaluate_model_mdn(model, val_loader, threshold=0.1):
    """
    Evaluate CNN–LSTM–MDN model (multi-head, top-pi selection per line).

    Args
    ----
    model : pl.LightningModule with multi-head MDN forward
    val_loader : DataLoader yielding (X, y, lengths)
    threshold : optional threshold for validity classification

    Returns
    -------
    dict with mse, mae, acc, f1
    """
    model.eval()
    all_preds_reg, all_labels_reg = [], []
    all_preds_len, all_labels_len = [], []

    device = next(model.parameters()).device

    for X_batch, y_batch, lengths in val_loader:
        # Move to device
        if isinstance(X_batch, dict):
            X_batch = {k: v.to(device) for k, v in X_batch.items()}
        else:
            X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        # Forward pass
        mdn_params = model(X_batch, lengths)

        B, num_lines = y_batch.shape
        y_pred_lines = []

        for i in range(num_lines):
            pi, mu = mdn_params['pi'], mdn_params['mu']  # both (B, n_components)
            
            # Pick component with highest pi per sample
            top_idx = torch.argmax(pi, dim=1, keepdim=True)     # (B,1)
            selected_mu = mu.gather(1, top_idx).squeeze(1)     # (B,)

            # Mask padded targets
            mask = (y_batch[:, i] != 0)
            selected_mu[~mask] = 0.0

            y_pred_lines.append(selected_mu)

        y_pred_all = torch.stack(y_pred_lines, dim=1)  # (B, num_lines)

        # Last valid step per sample
        y_len = (y_batch > 0).sum(dim=1)
        idx = torch.clamp(y_len - 1, min=0)
        y_true = y_batch[torch.arange(B), idx]
        y_pred = y_pred_all[torch.arange(B), idx]

        all_preds_reg.append(y_pred.cpu().numpy())
        all_labels_reg.append(y_true.cpu().numpy())

        # --- Validity classification ---
        pred_valid_lines = []
        for i in range(num_lines):
            pi = mdn_params['pi']    # (B, n_components)
            top_idx = torch.argmax(pi, dim=1, keepdim=True)
            pi_max = pi.gather(1, top_idx).squeeze(1)
            pred_valid_lines.append((pi_max > threshold).long())

        pred_valid_all = torch.stack(pred_valid_lines, dim=1)
        pred_valid_last = pred_valid_all[torch.arange(B), idx]
        true_valid_last = torch.ones_like(pred_valid_last)

        all_preds_len.extend(pred_valid_last.cpu().numpy().tolist())
        all_labels_len.extend(true_valid_last.cpu().numpy().tolist())

    # Concatenate all batches
    y_pred_reg = np.concatenate(all_preds_reg)
    y_true_reg = np.concatenate(all_labels_reg)

    mse = mean_squared_error(y_true_reg, y_pred_reg)
    mae = mean_absolute_error(y_true_reg, y_pred_reg)
    acc = accuracy_score(all_labels_len, all_preds_len)
    f1 = f1_score(all_labels_len, all_preds_len)

    print("mse:", mse, "mae:", mae, "acc:", acc, "f1:", f1)
    return {"mse": mse, "mae": mae, "acc": acc, "f1": f1}

# ---------------- Train ---------------- #
def train_model(
    data_csv,
    labels_csv,
    model_out_dir="models/saved_models",
    do_validation=True,
    hidden_dim=32,
    num_layers=1,
    lr=0.001,
    feature_eng=15,
    n_components=9,
    dropout = 0.1,
    batch_size=32,
    max_epochs=2,
    save_model=False,
    return_val_accuracy = True,
    test_mode = False,
    early_stop = False
):

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_out = f"{model_out_dir}/lstm_model_multireg_{timestamp}.pt"
    meta_out  = f"{model_out_dir}/lstm_meta_multireg_{timestamp}.pkl"

    pipeline = FeaturePipeline(
        steps=[
            # make_step(add_label_normalized_candles),
            make_step(add_candle_rocp),
            make_step(add_candle_shape_features,seperatable = "complete"),
            make_step(drop_columns, cols_to_drop=["open","high","low","close","volume"]),
            
        ],
        norm_methods={
            "main": {
                "upper_shadow": "standard", "body": "standard", "lower_shadow": "standard",
                # "open_dif":"standard","close_dif":"standard","high_dif":"standard","low_dif":"standard"
                # "upper_body_ratio": "standard", "lower_body_ratio": "standard",
                # "upper_lower_body_ratio": "standard", "Candle_Color": "standard"
            }
        },
        per_window_flags=[
            False, 
          False, 
          True
                ]
    )
    # Preprocess: pad linePrices and sequences

    if do_validation:
        train_ds, val_ds, df, feature_cols, max_len_y = preprocess_sequences_csv_multilines(
            data_csv, labels_csv,
            val_split=True,
            for_xgboost=False,
            debug_sample=True,
            feature_pipeline=pipeline,
            preserve_order= True
        )
    else:
        train_ds, df, feature_cols, max_len_y = preprocess_sequences_csv_multilines(
            data_csv, labels_csv,
            val_split=False,
            for_xgboost=False,
            debug_sample=False,
            feature_pipeline=pipeline,
            preserve_order= True
        )
        val_ds = None
    print("features",feature_cols)
    sample = train_ds[0][0]  # first sample's features
    if isinstance(sample, dict):  # multiple feature groups
        input_dim = sample['main'].shape[1]
    else:  # single tensor
        input_dim = sample.shape[1]

    model = cnn_lstm(input_dim=input_dim, feature_eng= feature_eng, hidden_dim=hidden_dim, 
                     n_components=n_components,  lr=lr, dropout=dropout,num_lines=max_len_y)
    init_args = get_init_args(model, input_dim=input_dim,feature_eng= feature_eng
                              ,hidden_dim=hidden_dim, n_components=n_components,
                              lr=lr, dropout=dropout,num_lines=max_len_y)
    model_class_info = {
        "module": model.__class__.__module__,
        "class": model.__class__.__name__,
        "init_args": init_args
    }

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
    val_loader = DataLoader(val_ds, batch_size=batch_size, collate_fn=collate_batch) if val_ds else None
    
    # --- Debug / Test mode --- #
    if test_mode:
        save_model = False
        from itertools import islice

        # Try to grab 3rd batch; if not available, take first
        try:
            batch = next(islice(iter(train_loader), 2, 3))
        except StopIteration:
            batch = next(iter(train_loader))

        X_batch_dict, y_batch, lengths = batch

        print("🔍 Debug batch:")
        if isinstance(X_batch_dict, dict):
            print("  Keys in X_batch:", list(X_batch_dict.keys()))
        print("  y_batch shape:", y_batch.shape)
        print("  First label in batch:", y_batch[0])

        # --- Track real column names for each feature group ---
        feature_names_dict = {}
        for name, X_batch in X_batch_dict.items():
            if name == "main":
                # Use actual feature columns after preprocessing
                feature_names_dict[name] = feature_cols
            else:
                # For extra feature groups, fallback to generic names
                feature_names_dict[name] = [f"{name}_{i}" for i in range(X_batch.shape[2])]

        dfs = []
        for name, X_batch in X_batch_dict.items():
            print(f"\nFeature group: {name}")
            print("  X_batch shape:", X_batch.shape)
            print("  First sequence in batch (first  steps):\n", X_batch[0][:])

            batch_size_, seq_len, feature_dim = X_batch.shape
            df_part = pd.DataFrame(
                X_batch.reshape(batch_size_ * seq_len, feature_dim).numpy(),
                columns=feature_names_dict[name]
            )
            dfs.append(df_part)

        # Combine all feature groups horizontally
        global df_seq
        df_seq = pd.concat(dfs, axis=1)
        print("\n✅ Combined df_seq shape:", df_seq.shape)
        print("✅ Column names in df_seq:", df_seq.columns.tolist())

    # --- Early stopping --- #
    if early_stop == True:
        from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
        early_stop_callback = EarlyStopping(
            monitor="val_loss",   # metric to monitor (must be logged in your LightningModule)
            patience=10,          # number of epochs with no improvement before stopping
            min_delta=0.001,      # minimum improvement to qualify as "better"
            mode="min",           # "min" for loss, "max" for accuracy
            verbose=True
        )

        checkpoint_callback = ModelCheckpoint(
            dirpath=model_out_dir,
            filename="best_model",
            save_top_k=1,
            monitor="val_loss",
            mode="min"
        )
        callbacks=[early_stop_callback,checkpoint_callback]

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices=1,
        fast_dev_run=test_mode,
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
        callbacks= callbacks if early_stop else None
    )

    trainer.fit(model, train_loader, val_loader)

    if save_model:
        os.makedirs(model_out_dir, exist_ok=True)
        trainer.save_checkpoint(model_out)
        joblib.dump({
    "input_dim": input_dim,
    "hidden_dim": hidden_dim,
    "num_layers": num_layers,
    "max_len_y": max_len_y,
    "feature_cols": feature_cols,
    "scalers": pipeline.scalers,
    "pipeline_config": pipeline.export_config(),
    "model_class_info": model_class_info   # ✅ save model class info
}, meta_out)
        
    # --- Evaluation --- #
    if do_validation:
        mse, mae, acc, f1 = evaluate_model_mdn(model, val_loader)
        if return_val_accuracy:
            return {"mse": mse, "mae": mae, "acc": acc, "f1": f1}
        
if __name__ == "__main__":
    train_model(
        "/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles.csv",
        "/home/iatell/projects/meta-learning/data/line_seq_ordered.csv",
        save_model=True,
        do_validation=True,
        # test_mode = True
    )


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


features ['open_dif', 'high_dif', 'low_dif', 'close_dif']


2025-09-12 16:07:13.851383: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-12 16:07:14.092198: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757680634.176325   39011 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757680634.202129   39011 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1757680631.927969   39011 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


mse: 0.14302027225494385 mae: 0.2753087878227234 acc: 1.0 f1: 1.0


### cnn transforemer

In [None]:
import sys
from pathlib import Path

# Current notebook location
notebook_path = Path().resolve()

# Add parent folder (meta/) to sys.path
sys.path.append(str(notebook_path.parent))
import joblib
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from datetime import datetime
from preprocess.multi_regression_seq_dif2 import preprocess_sequences_csv_multilines
# from models.LSTM.lstm_multi_line_reg_seq_dif import LSTMMultiRegressor
from utils.make_step import make_step
from utils.padding_batch_reg import collate_batch
from utils.get_init_argumens import get_init_args
import pandas as pd
import io
import numpy as np
import os
from add_ons.drop_column import drop_columns
from add_ons.normalize_candle_seq import add_label_normalized_candles
from add_ons.feature_pipeline4 import FeaturePipeline
from add_ons.candle_dif_rate_of_change_percentage2 import add_candle_rocp
from add_ons.candle_rate_of_change import add_candle_ratios
from sklearn.metrics import accuracy_score, f1_score,mean_squared_error,mean_absolute_error
from utils.to_address import to_address
# ---------------- Evaluation ---------------- #
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, accuracy_score, f1_score
import torch

@torch.no_grad()
def evaluate_model_mdn(model, val_loader, threshold=0.1):
    """
    Evaluate CNN–LSTM–MDN model (multi-head, top-pi selection per line).

    Args
    ----
    model : pl.LightningModule with multi-head MDN forward
    val_loader : DataLoader yielding (X, y, lengths)
    threshold : optional threshold for validity classification

    Returns
    -------
    dict with mse, mae, acc, f1
    """
    model.eval()
    all_preds_reg, all_labels_reg = [], []
    all_preds_len, all_labels_len = [], []

    device = next(model.parameters()).device

    for X_batch, y_batch, lengths in val_loader:
        # Move to device
        if isinstance(X_batch, dict):
            X_batch = {k: v.to(device) for k, v in X_batch.items()}
        else:
            X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        # Forward pass
        mdn_params = model(X_batch, lengths)

        B, num_lines = y_batch.shape
        y_pred_lines = []

        for i in range(num_lines):
            pi, mu = mdn_params['pi'], mdn_params['mu']  # both (B, n_components)
            
            # Pick component with highest pi per sample
            top_idx = torch.argmax(pi, dim=1, keepdim=True)     # (B,1)
            selected_mu = mu.gather(1, top_idx).squeeze(1)     # (B,)

            # Mask padded targets
            mask = (y_batch[:, i] != 0)
            selected_mu[~mask] = 0.0

            y_pred_lines.append(selected_mu)

        y_pred_all = torch.stack(y_pred_lines, dim=1)  # (B, num_lines)

        # Last valid step per sample
        y_len = (y_batch > 0).sum(dim=1)
        idx = torch.clamp(y_len - 1, min=0)
        y_true = y_batch[torch.arange(B), idx]
        y_pred = y_pred_all[torch.arange(B), idx]

        all_preds_reg.append(y_pred.cpu().numpy())
        all_labels_reg.append(y_true.cpu().numpy())

        # --- Validity classification ---
        pred_valid_lines = []
        for i in range(num_lines):
            pi = mdn_params['pi']    # (B, n_components)
            top_idx = torch.argmax(pi, dim=1, keepdim=True)
            pi_max = pi.gather(1, top_idx).squeeze(1)
            pred_valid_lines.append((pi_max > threshold).long())

        pred_valid_all = torch.stack(pred_valid_lines, dim=1)
        pred_valid_last = pred_valid_all[torch.arange(B), idx]
        true_valid_last = torch.ones_like(pred_valid_last)

        all_preds_len.extend(pred_valid_last.cpu().numpy().tolist())
        all_labels_len.extend(true_valid_last.cpu().numpy().tolist())

    # Concatenate all batches
    y_pred_reg = np.concatenate(all_preds_reg)
    y_true_reg = np.concatenate(all_labels_reg)

    mse = mean_squared_error(y_true_reg, y_pred_reg)
    mae = mean_absolute_error(y_true_reg, y_pred_reg)
    acc = accuracy_score(all_labels_len, all_preds_len)
    f1 = f1_score(all_labels_len, all_preds_len)

    print("mse:", mse, "mae:", mae, "acc:", acc, "f1:", f1)
    return {"mse": mse, "mae": mae, "acc": acc, "f1": f1}

# ---------------- Train ---------------- #
def train_model(
    data_csv,
    labels_csv,
    model_out_dir="models/saved_models",
    do_validation=True,
    hidden_dim=32,
    num_layers=1,
    lr=0.001,
    batch_size=32,
    max_epochs=500,
    save_model=False,
    return_val_accuracy = True,
    test_mode = False,
    early_stop = False
):

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_out = f"{model_out_dir}/lstm_model_multireg_{timestamp}.pt"
    meta_out  = f"{model_out_dir}/lstm_meta_multireg_{timestamp}.pkl"

    pipeline = FeaturePipeline(
        steps=[
            # make_step(add_label_normalized_candles),
            make_step(add_candle_rocp),
            make_step(drop_columns, cols_to_drop=["open","high","low","close","volume"]),
            
        ],
        # norm_methods={
        #     "main": {
        #         "upper_shadow": "robust", "body": "standard", "lower_shadow": "standard",
        #         "upper_body_ratio": "standard", "lower_body_ratio": "standard",
        #         "upper_lower_body_ratio": "standard", "Candle_Color": "standard"
        #     }
        # },
        per_window_flags=[
            False, 
          False, 
        #   True
                ]
    )
    # Preprocess: pad linePrices and sequences

    if do_validation:
        train_ds, val_ds, df, feature_cols, max_len_y = preprocess_sequences_csv_multilines(
            data_csv, labels_csv,
            val_split=True,
            for_xgboost=False,
            debug_sample=True,
            feature_pipeline=pipeline,
            preserve_order= True
        )
    else:
        train_ds, df, feature_cols, max_len_y = preprocess_sequences_csv_multilines(
            data_csv, labels_csv,
            val_split=False,
            for_xgboost=False,
            debug_sample=False,
            feature_pipeline=pipeline,
            preserve_order= True
        )
        val_ds = None

    sample = train_ds[0][0]  # first sample's features
    if isinstance(sample, dict):  # multiple feature groups
        input_dim = sample['main'].shape[1]
    else:  # single tensor
        input_dim = sample.shape[1]

    model = cnn_transformer(input_dim, feature_eng=15, hidden_dim=32, n_components=9, num_lines=9, lr=1e-3, dropout=0.1
    )
    init_args = get_init_args(model, input_dim=input_dim,num_lines= max_len_y )

    model_class_info = {
        "module": model.__class__.__module__,
        "class": model.__class__.__name__,
        "init_args": init_args
    }

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
    val_loader = DataLoader(val_ds, batch_size=batch_size, collate_fn=collate_batch) if val_ds else None
    
    # --- Debug / Test mode --- #
    if test_mode:
        save_model = False
        from itertools import islice

        # Try to grab 3rd batch; if not available, take first
        try:
            batch = next(islice(iter(train_loader), 2, 3))
        except StopIteration:
            batch = next(iter(train_loader))

        X_batch_dict, y_batch, lengths = batch

        print("🔍 Debug batch:")
        if isinstance(X_batch_dict, dict):
            print("  Keys in X_batch:", list(X_batch_dict.keys()))
        print("  y_batch shape:", y_batch.shape)
        print("  First label in batch:", y_batch[0])

        # --- Track real column names for each feature group ---
        feature_names_dict = {}
        for name, X_batch in X_batch_dict.items():
            if name == "main":
                # Use actual feature columns after preprocessing
                feature_names_dict[name] = feature_cols
            else:
                # For extra feature groups, fallback to generic names
                feature_names_dict[name] = [f"{name}_{i}" for i in range(X_batch.shape[2])]

        dfs = []
        for name, X_batch in X_batch_dict.items():
            print(f"\nFeature group: {name}")
            print("  X_batch shape:", X_batch.shape)
            print("  First sequence in batch (first  steps):\n", X_batch[0][:])

            batch_size_, seq_len, feature_dim = X_batch.shape
            df_part = pd.DataFrame(
                X_batch.reshape(batch_size_ * seq_len, feature_dim).numpy(),
                columns=feature_names_dict[name]
            )
            dfs.append(df_part)

        # Combine all feature groups horizontally
        global df_seq
        df_seq = pd.concat(dfs, axis=1)
        print("\n✅ Combined df_seq shape:", df_seq.shape)
        print("✅ Column names in df_seq:", df_seq.columns.tolist())

    # --- Early stopping --- #
    if early_stop == True:
        from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
        early_stop_callback = EarlyStopping(
            monitor="val_loss",   # metric to monitor (must be logged in your LightningModule)
            patience=10,          # number of epochs with no improvement before stopping
            min_delta=0.001,      # minimum improvement to qualify as "better"
            mode="min",           # "min" for loss, "max" for accuracy
            verbose=True
        )

        checkpoint_callback = ModelCheckpoint(
            dirpath=model_out_dir,
            filename="best_model",
            save_top_k=1,
            monitor="val_loss",
            mode="min"
        )
        callbacks=[early_stop_callback,checkpoint_callback]

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices=1,
        fast_dev_run=test_mode,
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
        callbacks= callbacks if early_stop else None
    )

    trainer.fit(model, train_loader, val_loader)

    if save_model:
        os.makedirs(model_out_dir, exist_ok=True)
        trainer.save_checkpoint(model_out)
        joblib.dump({
    "input_dim": input_dim,
    "hidden_dim": hidden_dim,
    "num_layers": num_layers,
    "max_len_y": max_len_y,
    "feature_cols": feature_cols,
    "scalers": pipeline.scalers,
    "pipeline_config": pipeline.export_config(),
    "model_class_info": model_class_info   # ✅ save model class info
}, meta_out)
        
    # --- Evaluation --- #
    if do_validation:
        mse, mae, acc, f1 = evaluate_model_mdn(model, val_loader)
        if return_val_accuracy:
            return {"mse": mse, "mae": mae, "acc": acc, "f1": f1}
        
if __name__ == "__main__":
    train_model(
        "/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles.csv",
        "/home/iatell/projects/meta-learning/data/line_seq_ordered.csv",
        save_model=True,
        do_validation=True,
        test_mode = False
    )


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



=== DEBUG SAMPLE CHECK (Torch mode) ===

--- Sequence 0 ---
Label: [1.143628 0.       0.       0.       0.       0.       0.       0.
 0.      ] Encoded (padded): [1.143628 0.       0.       0.       0.       0.       0.       0.
 0.      ]
Shape: (5, 4)
First few rows of sequence:
 [[ 0.01562355 -0.00180042 -0.01639293  0.0093857 ]
 [ 0.00938704  0.12409948  0.04899828  0.12622231]
 [ 0.12622082 -0.00192766  0.09665822  0.00645032]
 [ 0.00645032 -0.00251821 -0.02505807 -0.05388233]
 [-0.04985064 -0.0454773  -0.17924407 -0.07724382]]



You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-09-11 22:54:43.322679: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-11 22:54:43.587133: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757618683.667056     774 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
  output = torch._nested_tensor_from_mask(
/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=500` reached.


AssertionError: only bool and floating types of src_key_padding_mask are supported

## two head lstm

In [16]:
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
from datetime import datetime
from preprocess.multi_regression_seq_dif2 import preprocess_sequences_csv_multilines
# from models.LSTM.lstm_multi_line_reg_seq_dif import LSTMMultiRegressor
from utils.print_batch import print_batch
from utils.to_address import to_address
from utils.json_to_csv import json_to_csv_in_memory
from utils.padding_batch_reg import collate_batch
import pandas as pd
import io
import numpy as np
import os
from sklearn.metrics import accuracy_score, f1_score
from add_ons.feature_pipeline4 import FeaturePipeline
from add_ons.drop_column import drop_columns
from add_ons.candle_dif_rate_of_change_percentage2 import add_candle_rocp
from add_ons.candle_proportion import add_candle_proportions
from add_ons.candle_rate_of_change import add_candle_ratios
from utils.make_step import make_step

# ---------------- Evaluation ---------------- #
def evaluate_model(model, val_loader, threshold=0.5):
    model.eval()
    all_preds_reg, all_labels_reg = [], []
    all_preds_len, all_labels_len = [], []

    with torch.no_grad():
        for X_batch, y_batch, lengths in val_loader:
            # Send to same device as model
            device = next(model.parameters()).device
            X_batch = {k: v.to(device) for k, v in X_batch.items()}
            y_batch = y_batch.to(device)
            lengths = lengths.to(device)

            # Forward pass: regression + length logits
            y_pred, len_logits = model(X_batch, lengths)

            # Regression targets
            all_preds_reg.append(y_pred.cpu().numpy())
            all_labels_reg.append(y_batch.cpu().numpy())

            # Length targets
            true_lengths = lengths.cpu().numpy()
            pred_lengths = model.predict_length(len_logits).cpu().numpy()

            all_labels_len.extend(true_lengths.tolist())
            all_preds_len.extend(pred_lengths.tolist())

    # ----- Regression metrics -----
    all_preds_reg = np.vstack(all_preds_reg)
    all_labels_reg = np.vstack(all_labels_reg)

    mse = ((all_preds_reg - all_labels_reg) ** 2).mean()
    mae = np.abs(all_preds_reg - all_labels_reg).mean()

    # ----- Length metrics -----


    acc = accuracy_score(all_labels_len, all_preds_len)
    f1 = f1_score(all_labels_len, all_preds_len, average="macro")

    print("\n📊 Validation Metrics:")
    print(f"  Regression → MSE: {mse:.6f}, MAE: {mae:.6f}")
    print(f"  Length     → Acc: {acc:.4f}, F1: {f1:.4f}")

    return {"mse": mse, "mae": mae, "acc": acc, "f1": f1}


# ---------------- Train ---------------- #
def train_model(
    data_csv,
    labels_csv,
    model_out_dir="models/saved_models",
    do_validation=True,
    hidden_dim=128,
    num_layers=1,
    lr=0.001,
    batch_size=32,
    max_epochs=50,
    save_model=True,
    return_val_accuracy = True,
    test_mode = True,
    early_stop = False
):

    pipeline = FeaturePipeline(
        steps=[
            # make_step(add_label_normalized_candles),
            make_step(add_candle_rocp),
            make_step(drop_columns, cols_to_drop=["open","high","low","close","volume"]),
            
        ],
        # norm_methods={
        #     "main": {
        #         "upper_shadow": "robust", "body": "standard", "lower_shadow": "standard",
        #         "upper_body_ratio": "standard", "lower_body_ratio": "standard",
        #         "upper_lower_body_ratio": "standard", "Candle_Color": "standard"
        #     }
        # },
        per_window_flags=[
            False, 
          False, 
        #   True
                ]
    )
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_out = f"{model_out_dir}/lstm_model_multireg_multihead_{timestamp}.pt"
    meta_out  = f"{model_out_dir}/lstm_meta_multireg_multihead_{timestamp}.pkl"

    # Preprocess: pad linePrices and sequences
    if do_validation:
        train_ds, val_ds, df, feature_cols, max_len_y = preprocess_sequences_csv_multilines(
            data_csv, labels_csv,
            val_split=True,
            for_xgboost=False,
            debug_sample=True,
            feature_pipeline=pipeline
        )
    else:
        train_ds, df, feature_cols, max_len_y = preprocess_sequences_csv_multilines(
            data_csv, labels_csv,
            val_split=False,
            for_xgboost=False,
            debug_sample=False
        )
        val_ds = None

    sample = train_ds[0][0]  # first sample's features
    if isinstance(sample, dict):  # multiple feature groups
        input_dim = sample['main'].shape[1]
    else:  # single tensor
        input_dim = sample.shape[1]

    model = LSTMMultiRegressor(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        max_len_y=max_len_y,
        lr=lr
    )
    init_args = {
    "input_dim": input_dim,
    "hidden_dim": hidden_dim,
    "num_layers": num_layers,
    "max_len_y": max_len_y,
    "lr": lr
}

    model_class_info = {
        "module": model.__class__.__module__,
        "class": model.__class__.__name__,
        "init_args": init_args
    }

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
    val_loader = DataLoader(val_ds, batch_size=batch_size, collate_fn=collate_batch) if val_ds else None
    # --- Early stopping --- #
    if early_stop == True:
        from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
        early_stop_callback = EarlyStopping(
            monitor="val_loss",   # metric to monitor (must be logged in your LightningModule)
            patience=10,          # number of epochs with no improvement before stopping
            min_delta=0.001,      # minimum improvement to qualify as "better"
            mode="min",           # "min" for loss, "max" for accuracy
            verbose=True
        )

        checkpoint_callback = ModelCheckpoint(
            dirpath=model_out_dir,
            filename="best_model",
            save_top_k=1,
            monitor="val_loss",
            mode="min"
        )
        callbacks=[early_stop_callback,checkpoint_callback]

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices=1,
        fast_dev_run=test_mode,
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
        callbacks= callbacks if early_stop else None
    )

    trainer.fit(model, train_loader, val_loader)

    # --- Debug / Test mode --- #
    if test_mode:
        save_model = False
        from itertools import islice

        # Try to grab 3rd batch; if not available, take first
        try:
            batch = next(islice(iter(train_loader), 2, 3))
        except StopIteration:
            batch = next(iter(train_loader))

        X_batch_dict, y_batch, lengths = batch

        print("🔍 Debug batch:")
        if isinstance(X_batch_dict, dict):
            print("  Keys in X_batch:", list(X_batch_dict.keys()))
        print("  y_batch shape:", y_batch.shape)
        print("  First label in batch:", y_batch[0])

        # --- Track real column names for each feature group ---
        feature_names_dict = {}
        for name, X_batch in X_batch_dict.items():
            if name == "main":
                # Use actual feature columns after preprocessing
                feature_names_dict[name] = feature_cols
            else:
                # For extra feature groups, fallback to generic names
                feature_names_dict[name] = [f"{name}_{i}" for i in range(X_batch.shape[2])]

        dfs = []
        for name, X_batch in X_batch_dict.items():
            print(f"\nFeature group: {name}")
            print("  X_batch shape:", X_batch.shape)
            print("  First sequence in batch (first  steps):\n", X_batch[0][:])

            batch_size_, seq_len, feature_dim = X_batch.shape
            df_part = pd.DataFrame(
                X_batch.reshape(batch_size_ * seq_len, feature_dim).numpy(),
                columns=feature_names_dict[name]
            )
            dfs.append(df_part)

        # Combine all feature groups horizontally
        global df_seq
        df_seq = pd.concat(dfs, axis=1)
        print("\n✅ Combined df_seq shape:", df_seq.shape)
        print("✅ Column names in df_seq:", df_seq.columns.tolist())

        
    if save_model:
        os.makedirs(model_out_dir, exist_ok=True)
        trainer.save_checkpoint(model_out)
        joblib.dump({
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "num_layers": num_layers,
            "max_len_y": max_len_y,
            "feature_cols": feature_cols,
            "scalers": pipeline.scalers,
            "pipeline_config": pipeline.export_config(),
            "model_class_info": model_class_info 
        }, meta_out)
        print(f"✅ Model saved to {model_out}")
        print(f"✅ Meta saved to {meta_out}")


        
    # --- Evaluation --- #
    if do_validation:
        mse, mae, acc, f1 = evaluate_model(model, val_loader)
        if return_val_accuracy:
            return {"mse": mse, "mae": mae, "acc": acc, "f1": f1}
        
if __name__ == "__main__":
    train_model(
        "/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles.csv",
        "/home/iatell/projects/meta-learning/data/seq_line_labels.csv",
        do_validation=True,
        test_mode = False
    )


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/trainer/configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type              | Params | Mode 
----------------------------------------------------------
0 | lstm        | LSTM              | 68.6 K | train
1 | fc_reg      | Linear            | 774    | train
2 | fc_len      | Linear            | 774    | train
3 | loss_fn_reg | MSELoss           | 0      | train
4 | loss_fn_len | BCEWithLogitsLoss | 0      | train
----------------------------------------------------------
70.


=== DEBUG SAMPLE CHECK (Torch mode) ===

--- Sequence 0 ---
Label: [1.086008 1.126277 1.165107 0.970955 0.       0.      ] Encoded (padded): [1.086008 1.126277 1.165107 0.970955 0.       0.      ]
Shape: (5, 4)
First few rows of sequence:
 [[ 0.00645032 -0.00251821 -0.02505807 -0.05388233]
 [-0.04985064 -0.0454773  -0.17924407 -0.07724382]
 [-0.08115927 -0.05037893  0.09358804 -0.03372177]
 [-0.03365467 -0.03511871 -0.06278902  0.03521458]
 [ 0.03742796  0.00087057 -0.13184595 -0.11191386]]



/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


✅ Model saved to models/saved_models/lstm_model_multireg_multihead_20250910_154556.pt
✅ Meta saved to models/saved_models/lstm_meta_multireg_multihead_20250910_154556.pkl

📊 Validation Metrics:
  Regression → MSE: 0.439017, MAE: 0.504375
  Length     → Acc: 0.0667, F1: 0.0096


## xgboost two head

In [None]:
import sys
from pathlib import Path

# Current notebook location
notebook_path = Path().resolve()

# Add parent folder (meta/) to sys.path
sys.path.append(str(notebook_path.parent))
import joblib
import joblib
from datetime import datetime
import xgboost as xgb
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import os
import io
import numpy as np
import pandas as pd
import warnings
from sklearn.model_selection import train_test_split
from utils.make_step import make_step
from preprocess.multi_regression_seq_dif2 import preprocess_sequences_csv_multilines
from add_ons.drop_column import drop_columns
from add_ons.feature_pipeline4 import FeaturePipeline
from add_ons.normalize_candle_seq import add_label_normalized_candles
from add_ons.candle_dif_rate_of_change_percentage2 import add_candle_rocp
from add_ons.candle_rate_of_change import add_candle_ratios
# ---------------- Evaluation ---------------- #
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import numpy as np
import warnings

def evaluate_model(model, length_model, X_val, y_val, true_lengths, return_sequences=False):
    """
    Evaluate multi-output regression with predicted sequence lengths.
    Permutation-invariant: sorts both predictions and true values before computing metrics.
    Can optionally return the predicted vs true sequences for inspection.
    """
    y_pred_full = model.predict(X_val)
    pred_lengths = np.round(length_model.predict(X_val)).astype(int)

    print("\n📊 Validation Report (Multi-Regression with variable-length sequences):")
    mse_list, mae_list, r2_list = [], [], []

    pred_vs_true_list = []  # store predicted vs true sequences if needed

    for i, (pred, pred_len, true_y, true_len) in enumerate(zip(y_pred_full, pred_lengths, y_val, true_lengths)):
        L = min(pred_len, true_len)
        pred_trunc = np.sort(pred[:L])       # sort predictions for permutation-invariant metrics
        true_trunc = np.sort(true_y[:L])     # sort true values

        mse = mean_squared_error(true_trunc, pred_trunc)
        mae = mean_absolute_error(true_trunc, pred_trunc)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                r2 = r2_score(true_trunc, pred_trunc)
            except ValueError:
                r2 = np.nan

        mse_list.append(mse)
        mae_list.append(mae)
        r2_list.append(r2)

        print(f"\nSample {i}:")
        print(f"  Predicted length: {pred_len}, True length: {true_len}")
        print(f"  MSE: {mse:.6f}, MAE: {mae:.6f}, R²: {r2:.6f}")
        print(f"  Predicted lines: {pred_trunc}")
        print(f"  True lines     : {true_trunc}")

        if return_sequences:
            pred_vs_true_list.append((pred_trunc, true_trunc))

    print("\n--- Global Scores ---")
    print(f"Mean MSE: {np.mean(mse_list):.6f}")
    print(f"Mean MAE: {np.mean(mae_list):.6f}")
    print(f"Mean R²: {np.nanmean(r2_list):.6f}")

    results = {"mse": np.mean(mse_list), "mae": np.mean(mae_list), "r2": np.nanmean(r2_list)}
    
    if return_sequences:
        results["pred_vs_true"] = pred_vs_true_list
    
    return results

# ---------------- Train ---------------- #
def train_model_xgb_multireg(
    data_csv,
    labels_csv,
    model_out_dir="models/saved_models",
    do_validation=True,
    n_estimators=1000,
    max_depth=16,
    learning_rate=0.05,
    subsample=0.8,
    colsample_bytree=0.8,
    save_model=False,
    return_val_metrics=True,
    **model_params
):
    """
    Train a multi-output XGBoost regressor with a linked sequence-length predictor.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_out = f"{model_out_dir}/xgb_model_multireg_{timestamp}.pkl"
    length_model_out = f"{model_out_dir}/xgb_model_seq_len_{timestamp}.pkl"
    meta_out = f"{model_out_dir}/xgb_meta_multireg_{timestamp}.pkl"

    pipeline = FeaturePipeline(
        steps=[
            # make_step(add_label_normalized_candles),
            make_step(add_candle_rocp),
            make_step(drop_columns, cols_to_drop=["open","high","low","close","volume"]),
            
        ],
        # norm_methods={
        #     "main": {
        #         "upper_shadow": "robust", "body": "standard", "lower_shadow": "standard",
        #         "upper_body_ratio": "standard", "lower_body_ratio": "standard",
        #         "upper_lower_body_ratio": "standard", "Candle_Color": "standard"
        #     }
        # },
        per_window_flags=[
            False, 
          False, 
        #   True
                ]
    )
    # --- Preprocess data ---
    if do_validation:
        X_train, y_train, X_val, y_val, df, feature_cols, max_len_y, seq_lengths_true = preprocess_sequences_csv_multilines(
            data_csv, labels_csv,
            val_split=True,
            for_xgboost=True,
            debug_sample=True,
            feature_pipeline=pipeline
        )
    else:
        X_train, y_train, df, feature_cols, max_len_y, seq_lengths_true = preprocess_sequences_csv_multilines(
            data_csv, labels_csv,
            val_split=False,
            for_xgboost=True,
            feature_pipeline=pipeline
        )
        X_val, y_val = None, None


    # --- Sequence length targets ---
    if do_validation:
        idx_train, idx_val = train_test_split(
            np.arange(len(seq_lengths_true)),
            test_size=0.2,  # match your preprocess split
            random_state=42
        )
        train_lengths = np.array(seq_lengths_true)[idx_train]
        val_lengths   = np.array(seq_lengths_true)[idx_val]
    else:
        train_lengths = np.array(seq_lengths_true)

    # --- Train max-line regression ---
    xgb_model = xgb.XGBRegressor(
        n_estimators=n_estimators,
        max_depth=max_depth,
        learning_rate=learning_rate,
        subsample=subsample,
        colsample_bytree=colsample_bytree,
        objective="reg:squarederror",
        **model_params
    )
    model = MultiOutputRegressor(xgb_model, n_jobs=-1)
    model.fit(X_train, y_train)

    # --- Train length predictor ---
    xgb_len_model = xgb.XGBRegressor(
        n_estimators=n_estimators,
        max_depth=max_depth,
        learning_rate=learning_rate,
        subsample=subsample,
        colsample_bytree=colsample_bytree,
        objective="reg:squarederror",
        **model_params
    )
    xgb_len_model.fit(X_train, train_lengths)


    # --- Save models ---
    if save_model:
        os.makedirs(model_out_dir, exist_ok=True)
        
        # Save trained models
        joblib.dump(model, model_out)
        joblib.dump(xgb_len_model, length_model_out)
        
        # Save full metadata
        meta_dict = {
            "feature_cols": feature_cols,
            "target_dim": max_len_y,
            "n_estimators": n_estimators,
            "max_depth": max_depth,
            "learning_rate": learning_rate,
            "subsample": subsample,
            "colsample_bytree": colsample_bytree,
            "model_params": model_params,
            "scalers": pipeline.scalers,
            "pipeline_config": pipeline.export_config(),
            "multioutput_wrapper": {
                "class": model.__class__.__name__,
                "module": model.__class__.__module__,
            }
        }
        joblib.dump(meta_dict, meta_out)
        
        print(f"✅ Model saved to {model_out}")
        print(f"✅ Length predictor saved to {length_model_out}")
        print(f"✅ Metadata saved to {meta_out}")
    # --- Evaluate ---
    val_metrics = None
    if do_validation:
        metrics = evaluate_model(model, xgb_len_model, X_val, y_val, val_lengths, return_sequences=True)


    if return_val_metrics:
        return val_metrics

# ---------------- Main ---------------- #
if __name__ == "__main__":
    train_model_xgb_multireg(
        "/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles.csv",
        "/home/iatell/projects/meta-learning/data/line_seq_ordered.csv",
        do_validation=True,
        save_model=False
    )


In [30]:
df_seq = df_seq.loc[~(df_seq==0).all(axis=1)]
df_seq

Unnamed: 0,open_dif,high_dif,low_dif,close_dif,upper_shadow,lower_shadow,body,color
0,-0.006806,-0.001821,0.006829,-0.000681,0.007334,0.010904,0.000881,0.3
1,-0.001094,-0.005757,0.002917,-0.003194,0.002631,0.004840,0.002981,0.3
2,-0.002852,-0.000483,-0.016987,-0.004921,0.005014,0.016906,0.005050,0.3
3,-0.005208,-0.007726,0.004962,-0.006193,0.002470,0.005871,0.006035,0.3
4,-0.006068,-0.001653,-0.002766,-0.003026,0.006924,0.005611,0.002993,0.3
...,...,...,...,...,...,...,...,...
1212,-0.041315,0.060291,0.021794,0.106569,0.004192,0.000254,0.096806,0.7
1213,0.107181,0.076916,0.090062,0.058212,0.021941,0.015712,0.055011,0.7
1214,0.058297,0.010164,0.039187,0.015989,0.016082,0.033486,0.015658,0.7
1215,0.015517,0.072815,0.039727,0.086572,0.003218,0.010444,0.080029,0.7


# server

## MDN server

### cnn lstm

In [None]:
import sys
from pathlib import Path

# Current notebook location
notebook_path = Path().resolve()

# Add parent folder (meta/) to sys.path
sys.path.append(str(notebook_path.parent))
import glob
import joblib
import torch
import numpy as np
import pandas as pd
from flask import Flask, request, jsonify, render_template
from servers.pre_process.multi_reg_dif_seq2 import ServerPreprocess, import_class, build_pipeline_from_config
# from models.LSTM.cnn_lstm_mdn import CNNLSTM_MDN  # <-- your updated "last-output" model

app = Flask(__name__)

# ---------------- Load model and meta ----------------
meta_path = glob.glob("/home/iatell/projects/meta-learning/play_grounds/models/saved_models/lstm_meta_multireg_*.pkl")[0]
state_path = glob.glob("/home/iatell/projects/meta-learning/play_grounds/models/saved_models/lstm_model_multireg*.pt")[0]

meta = joblib.load(meta_path)
FEATURES = meta['feature_cols']
print("features",FEATURES)
# ---------------- Model ----------------
# Reconstruct model class
#for python file:
# model_cls_info = meta["model_class_info"]
# ModelClass = import_class(model_cls_info["module"], model_cls_info["class"])
model_cls_info = meta["model_class_info"]
ModelClass = cnn_lstm
# Initialize model with original args
model = ModelClass(**model_cls_info["init_args"])
model = cnn_lstm.load_from_checkpoint(state_path)
model.eval()

# ---------------- Load data ----------------
df = pd.read_csv( "/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles.csv", parse_dates=['timestamp'])

# ---------------- Setup pipeline ----------------
pipeline = build_pipeline_from_config(meta["pipeline_config"])
pipeline.scalers = meta["scalers"]

# Stateful preprocessing instance
preproc = ServerPreprocess(feature_pipeline=pipeline)


# ---------------- Routes ----------------
@app.route("/")
def home():
    return render_template("sequential.html")


@app.route("/get_and_add_data")
def get_and_add_data():
    dense = df.set_index('timestamp').asfreq('D').ffill()
    initial_seq_len = 21
    next_idx = request.args.get("idx", type=int)
    if next_idx is None:
        # First call → load initial candles
        if len(preproc.dataset) == 0:
            for _, row in dense.iloc[:initial_seq_len].iterrows():
                preproc.add_candle(row)
        candles = [
            {'time': int(ts.timestamp()),
             'open': float(row.open),
             'high': float(row.high),
             'low': float(row.low),
             'close': float(row.close)}
            for ts, row in dense.iloc[:initial_seq_len].iterrows()
        ]
        print("Returning initial candles:", candles)

        return jsonify({
            "initial_seq_len": initial_seq_len,
            "next_idx": initial_seq_len,
            "candles": candles
        })
    else:
        # Subsequent calls → 1 candle
        if next_idx >= len(dense):
            print("Reached end of data at index:", next_idx)
            return jsonify({"error": "End of data"}), 404

        row = dense.iloc[next_idx]
        candle = {
            'time': int(row.name.timestamp()),
            'open': float(row.open),
            'high': float(row.high),
            'low': float(row.low),
            'close': float(row.close)
        }

        # ✅ Add to preproc automatically
        preproc.add_candle(row)

        return jsonify({
            "next_idx": next_idx + 1,
            "candle": candle
        })


@app.route("/predict", methods=['POST'])
def predict():
    data = request.get_json(force=True)
    seq_len = data.get("seq_len")

    if not seq_len or not isinstance(seq_len, int):
        return jsonify({"error": "Provide 'seq_len' as an int"}), 400

    try:
        # prepare subsequence from current state
        seq_dict = preproc.prepare_seq(seq_len)  # returns dict of DataFrames
    except ValueError as e:
        return jsonify({"error": str(e)}), 400

    # Convert dict of DataFrames to dict of tensors
    dict_x = {k: torch.from_numpy(v.values.astype(np.float32)).unsqueeze(0)
            for k, v in seq_dict.items()}


    with torch.no_grad():
        mdn_out = model(dict_x)

    pi    = mdn_out['pi'][0].cpu().numpy()
    mu    = mdn_out['mu'][0].cpu().numpy()
    sigma = mdn_out['sigma'][0].cpu().numpy()
    last_close = preproc.reference_dataset.iloc[-1]['close']

    return jsonify({
        'pred_prices': (last_close * mu).tolist(),
        'pred_sigmas': (last_close * sigma).tolist(),
        'pi': pi.tolist()
    })


if __name__ == '__main__':
    app.run(debug=True, use_reloader=False)


features ['open_dif', 'high_dif', 'low_dif', 'close_dif']
 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [12/Sep/2025 17:53:33] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:34] "GET /get_and_add_data?init=1 HTTP/1.1" 200 -


Returning initial candles: [{'time': 1514764800, 'open': 13707.91, 'high': 13818.55, 'low': 12750.0, 'close': 13380.0}, {'time': 1514851200, 'open': 13382.16, 'high': 15473.49, 'low': 12890.02, 'close': 14675.11}, {'time': 1514937600, 'open': 14690.0, 'high': 15307.56, 'low': 14150.0, 'close': 14919.51}, {'time': 1515024000, 'open': 14919.51, 'high': 15280.0, 'low': 13918.04, 'close': 15059.54}, {'time': 1515110400, 'open': 15059.56, 'high': 17176.24, 'low': 14600.0, 'close': 16960.39}, {'time': 1515196800, 'open': 16960.39, 'high': 17143.13, 'low': 16011.21, 'close': 17069.79}, {'time': 1515283200, 'open': 17069.79, 'high': 17099.96, 'low': 15610.0, 'close': 16150.03}, {'time': 1515369600, 'open': 16218.85, 'high': 16322.3, 'low': 12812.0, 'close': 14902.54}, {'time': 1515456000, 'open': 14902.54, 'high': 15500.0, 'low': 14011.05, 'close': 14400.0}, {'time': 1515542400, 'open': 14401.0, 'high': 14955.66, 'low': 13131.31, 'close': 14907.09}, {'time': 1515628800, 'open': 14940.0, 'high'

127.0.0.1 - - [12/Sep/2025 17:53:34] "GET /get_and_add_data?idx=21 HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:34] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:34] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:34] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:34] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:34] "POST /predict HTTP/1.1" 200 -



=== Final seq_dict keys and shapes ===
main: shape (3, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (3, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (5, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (5, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (8, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (8, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (13, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (13, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (20, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_can

127.0.0.1 - - [12/Sep/2025 17:53:38] "GET /get_and_add_data?idx=22 HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:38] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:38] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:38] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:38] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:38] "POST /predict HTTP/1.1" 200 -



=== Final seq_dict keys and shapes ===
main: shape (3, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (3, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (5, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (5, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (8, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (8, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (20, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (20, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (13, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_can

127.0.0.1 - - [12/Sep/2025 17:53:39] "GET /get_and_add_data?idx=23 HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "GET /get_and_add_data?idx=24 HTTP/1.1" 200 -



=== Final seq_dict keys and shapes ===
main: shape (3, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (3, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (5, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (5, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (13, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (13, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (8, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (8, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (20, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_can

127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -



=== Final seq_dict keys and shapes ===
main: shape (20, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (20, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']


127.0.0.1 - - [12/Sep/2025 17:53:39] "GET /get_and_add_data?idx=25 HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [12/Sep/2025 17:53:39] "POST /predict HTTP/1.1" 200 -



=== Final seq_dict keys and shapes ===
main: shape (3, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (3, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (5, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (5, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (8, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (8, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (13, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_candle_shape_features: shape (13, 4), columns: ['upper_shadow', 'lower_shadow', 'body', 'color']

=== Final seq_dict keys and shapes ===
main: shape (20, 4), columns: ['open_dif', 'high_dif', 'low_dif', 'close_dif']
add_can

## lstm two head

In [3]:
import sys
from pathlib import Path

# Current notebook location
notebook_path = Path().resolve()

# Add parent folder (meta/) to sys.path
sys.path.append(str(notebook_path.parent))
from pathlib import Path
import glob
import joblib
import torch
import numpy as np
import pandas as pd
from flask import Flask, request, jsonify, render_template
from servers.pre_process.multi_reg_dif_seq import ServerPreprocess, import_class, build_pipeline_from_config
# from models.LSTM.two_head_lstm import LSTMMultiRegressor  # your new model

app = Flask(__name__)

# ---------------- Load model and meta ----------------
meta_path = glob.glob("/home/iatell/projects/meta-learning/play_grounds/models/saved_models/lstm_meta_multireg_multihead_*.pkl")[0]
state_path = glob.glob("/home/iatell/projects/meta-learning/play_grounds/models/saved_models/lstm_model_multireg_multihead_*.pt")[0]


meta = joblib.load(meta_path)
FEATURES = meta['feature_cols']
print("features", FEATURES)

# Initialize model class
model_cls_info = meta["model_class_info"]
init_args = model_cls_info["init_args"]
model = LSTMMultiRegressor.load_from_checkpoint(state_path, **init_args)
model.eval()


# ---------------- Load data ----------------
df = pd.read_csv("/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles.csv", parse_dates=['timestamp'])

# ---------------- Setup pipeline ----------------
pipeline = build_pipeline_from_config(meta["pipeline_config"])
pipeline.scalers = meta["scalers"]
preproc = ServerPreprocess(feature_pipeline=pipeline)

# ---------------- Routes ----------------
@app.route("/")
def home():
    return render_template("two_head.html")


@app.route("/get_and_add_data")
def get_and_add_data():
    dense = df.set_index('timestamp').asfreq('D').ffill()
    initial_seq_len = 21
    next_idx = request.args.get("idx", type=int)

    if next_idx is None:
        if len(preproc.dataset) == 0:
            for _, row in dense.iloc[:initial_seq_len].iterrows():
                preproc.add_candle(row)
        candles = [{'time': int(ts.timestamp()),
                    'open': float(row.open),
                    'high': float(row.high),
                    'low': float(row.low),
                    'close': float(row.close)}
                   for ts, row in dense.iloc[:initial_seq_len].iterrows()]
        return jsonify({
            "initial_seq_len": initial_seq_len,
            "next_idx": initial_seq_len,
            "candles": candles
        })
    else:
        if next_idx >= len(dense):
            return jsonify({"error": "End of data"}), 404
        row = dense.iloc[next_idx]
        candle = {'time': int(row.name.timestamp()),
                  'open': float(row.open),
                  'high': float(row.high),
                  'low': float(row.low),
                  'close': float(row.close)}
        preproc.add_candle(row)
        return jsonify({"next_idx": next_idx + 1, "candle": candle})


@app.route("/predict", methods=['POST'])
def predict():
    data = request.get_json(force=True)
    seq_len = data.get("seq_len")

    if not seq_len or not isinstance(seq_len, int):
        return jsonify({"error": "Provide 'seq_len' as an int"}), 400

    try:
        seq_df = preproc.prepare_seq(seq_len)
    except ValueError as e:
        return jsonify({"error": str(e)}), 400

    X_np = seq_df[FEATURES].values.astype(np.float32)
    X_t = torch.from_numpy(X_np).unsqueeze(0)
    lengths = torch.tensor([seq_len], dtype=torch.long)

    with torch.no_grad():
        y_pred, len_logits = model({"main": X_t}, lengths)

    last_close = preproc.reference_dataset.iloc[-1]['close']
    pred_prices = (last_close * y_pred[0]).tolist()
    pred_len = model.predict_length(len_logits).item()

    return jsonify({
        "pred_prices": pred_prices,
        "pred_len": pred_len
    })


if __name__ == "__main__":
    app.run(debug=True, use_reloader=False)


features ['open_dif', 'high_dif', 'low_dif', 'close_dif']
 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [10/Sep/2025 16:07:57] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:07:57] "GET /get_and_add_data HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:07:58] "GET /get_and_add_data?idx=21 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:07:58] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:07:58] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:07:58] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:07:58] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:07:58] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len tensor([[2.9861, 2.6310, 2.6556, 2.9212, 2.7351, 2.5645]])
len tensor([[3.2939, 2.8888, 2.9717, 3.2023, 3.0238, 2.8708]])
len tensor([[3.2047, 2.8153, 2.9138, 3.1162, 2.9445, 2.8013]])


127.0.0.1 - - [10/Sep/2025 16:08:20] "GET /get_and_add_data?idx=22 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:20] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:20] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])


127.0.0.1 - - [10/Sep/2025 16:08:24] "GET /get_and_add_data?idx=23 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:24] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:24] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:26] "GET /get_and_add_data?idx=24 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:26] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:26] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:26] "GET /get_and_add_data?idx=25 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:26] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:26] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:27] "GET /get_and_add_data?idx=26 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "GET /get_and_add_data?idx=27 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:27] "GET /get_and_add_data?idx=28 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "GET /get_and_add_data?idx=29 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:27] "GET /get_and_add_data?idx=30 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "GET /get_and_add_data?idx=31 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:27] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:28] "GET /get_and_add_data?idx=32 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "GET /get_and_add_data?idx=33 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])


127.0.0.1 - - [10/Sep/2025 16:08:28] "GET /get_and_add_data?idx=34 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "GET /get_and_add_data?idx=35 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:28] "GET /get_and_add_data?idx=36 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "GET /get_and_add_data?idx=37 HTTP/1.1" 200 -


len len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])


127.0.0.1 - - [10/Sep/2025 16:08:28] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:28] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:29] "GET /get_and_add_data?idx=38 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:29] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:29] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:29] "GET /get_and_add_data?idx=39 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:29] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:29] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:29] "GET /get_and_add_data?idx=40 HTTP/1.1" 200 -


len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:29] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:29] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:29] "GET /get_and_add_data?idx=41 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:29] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:29] "POST /predict HTTP/1.1" 200 -


len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:29] "GET /get_and_add_data?idx=42 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:29] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:29] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "GET /get_and_add_data?idx=43 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:30] "GET /get_and_add_data?idx=44 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "GET /get_and_add_data?idx=45 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:30] "GET /get_and_add_data?idx=46 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "GET /get_and_add_data?idx=47 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:30] "GET /get_and_add_data?idx=48 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:30] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "GET /get_and_add_data?idx=49 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:31] "GET /get_and_add_data?idx=50 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "GET /get_and_add_data?idx=51 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "POST /predict HTTP/1.1" 200 -


lenlen  tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:31] "GET /get_and_add_data?idx=52 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "GET /get_and_add_data?idx=53 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:31] "GET /get_and_add_data?idx=54 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:31] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "GET /get_and_add_data?idx=55 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])


127.0.0.1 - - [10/Sep/2025 16:08:32] "GET /get_and_add_data?idx=56 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "GET /get_and_add_data?idx=57 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:32] "GET /get_and_add_data?idx=58 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "GET /get_and_add_data?idx=59 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
len tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])
len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])


127.0.0.1 - - [10/Sep/2025 16:08:32] "GET /get_and_add_data?idx=60 HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [10/Sep/2025 16:08:32] "POST /predict HTTP/1.1" 200 -


len len tensor([[1.9650, 1.7708, 1.7662, 1.9512, 1.7988, 1.6682]])
tensor([[1.0439, 0.9929, 0.9968, 1.0580, 0.9667, 0.8891]])


## xgboost two head

In [None]:
import sys
from pathlib import Path

# Current notebook location
notebook_path = Path().resolve()
sys.path.append(str(notebook_path.parent))

import glob
import joblib
import numpy as np
import pandas as pd
from flask import Flask, request, jsonify, render_template

from servers.pre_process.multi_reg_dif_seq import ServerPreprocess, build_pipeline_from_config

# ---------------- Flask ----------------
app = Flask(__name__)

# ---------------- Load models + meta ----------------
meta_path = glob.glob("/home/iatell/projects/meta-learning/play_grounds/models/saved_models/xgb_meta_multireg_*.pkl")[0]
model_path = glob.glob("/home/iatell/projects/meta-learning/play_grounds/models/saved_models/xgb_model_multireg_*.pkl")[0]
len_model_path = glob.glob("/home/iatell/projects/meta-learning/play_grounds/models/saved_models/xgb_model_seq_len_*.pkl")[0]

meta = joblib.load(meta_path)
FEATURES = meta['feature_cols']
print("features", FEATURES)

# Models
model = joblib.load(model_path)       # MultiOutputRegressor with XGBRegressor inside
len_model = joblib.load(len_model_path)

# ---------------- Load data ----------------
df = pd.read_csv("/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles.csv", parse_dates=['timestamp'])

# ---------------- Setup pipeline ----------------
pipeline = build_pipeline_from_config(meta["pipeline_config"])
pipeline.scalers = meta["scalers"]

# Stateful preprocessing instance
preproc = ServerPreprocess(feature_pipeline=pipeline)


# ---------------- Routes ----------------
@app.route("/")
def home():
    return render_template("xgboost_seq.html")


@app.route("/get_and_add_data")
def get_and_add_data():
    dense = df.set_index('timestamp').asfreq('D').ffill()
    initial_seq_len = 21
    next_idx = request.args.get("idx", type=int)

    if next_idx is None:
        # First call → load initial candles
        if len(preproc.dataset) == 0:
            for _, row in dense.iloc[:initial_seq_len].iterrows():
                preproc.add_candle(row)

        candles = [
            {'time': int(ts.timestamp()),
             'open': float(row.open),
             'high': float(row.high),
             'low': float(row.low),
             'close': float(row.close)}
            for ts, row in dense.iloc[:initial_seq_len].iterrows()
        ]
        return jsonify({
            "initial_seq_len": initial_seq_len,
            "next_idx": initial_seq_len,
            "candles": candles
        })
    else:
        # Subsequent calls → 1 candle
        if next_idx >= len(dense):
            return jsonify({"error": "End of data"}), 404

        row = dense.iloc[next_idx]
        candle = {
            'time': int(row.name.timestamp()),
            'open': float(row.open),
            'high': float(row.high),
            'low': float(row.low),
            'close': float(row.close)
        }

        preproc.add_candle(row)

        return jsonify({
            "next_idx": next_idx + 1,
            "candle": candle
        })


@app.route("/predict", methods=['POST'])
def predict():
    data = request.get_json(force=True)
    seq_len = data.get("seq_len")

    if not seq_len or not isinstance(seq_len, int):
        return jsonify({"error": "Provide 'seq_len' as an int"}), 400

    try:
        # Use your XGBoost + preproc logic
        X_np = preproc.prepare_xgboost_seq(seq_len, model=len_model)
        pred_len = int(np.round(len_model.predict(X_np))[0])
        y_pred_full = model.predict(X_np)[0]
        pred_trunc = np.sort(y_pred_full[:pred_len])
        last_close = preproc.reference_dataset.iloc[-1]['close']
        pred_scaled = (last_close * pred_trunc).tolist()

        return jsonify({
            'pred_length': pred_len,
            'pred_lines': pred_scaled
        })
    except Exception as e:
        # <-- This will print the actual exception in the console
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500





if __name__ == '__main__':
    app.run(debug=True, use_reloader=False)


# Tensorboard

In [5]:
import os
import subprocess
import webbrowser

logdir = "lightning_logs"

# 1. Find all version folders
versions = [d for d in os.listdir(logdir) if d.startswith("version_")]
if not versions:
    raise ValueError("No version folders found in lightning_logs")

# 2. Sort numerically and get the latest
versions.sort(key=lambda x: int(x.split("_")[1]))
latest_version = versions[-1]
latest_logdir = os.path.join(logdir, latest_version)
print(f"Launching TensorBoard for: {latest_logdir}")

# 3. Choose a port
port = 6006

# 4. Launch TensorBoard as a background process
subprocess.Popen(["tensorboard", f"--logdir={latest_logdir}", f"--port={port}"])

# 5. Open TensorBoard in default browser
webbrowser.open(f"http://localhost:{port}")


Launching TensorBoard for: lightning_logs/version_120


2025-09-10 16:16:54.726542: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-10 16:16:54.736657: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757508414.749226    4333 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757508414.752893    4333 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1757508414.762872    4333 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

True

gio: http://localhost:6006: Operation not supported

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

E0910 16:16:57.287815 126933973667968 program.py:300] TensorBoard could not bind to port 6006, it was already in use
ERROR: TensorBoard could not bind to port 6006, it was already in use
