<a href="https://colab.research.google.com/github/DaniilMz/OMM/blob/main/Competition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import onnx
import onnx.helper as helper

from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import KFold

from typing import Tuple, List, Optional, Dict, Any

In [None]:
# Load the training and validation splits shipped with the competition package.
train_data = pd.read_parquet("./competition_package/datasets/train.parquet")
eval_data = pd.read_parquet("./competition_package/datasets/valid.parquet")

# Divide both target columns by 100 in each split (same rescaling for train and eval).
for _split in (train_data, eval_data):
    _split[['t0', 't1']] = _split[['t0', 't1']] / 100
del _split

In [None]:
# Group the training columns by name prefix. Matching is purely by prefix, so any
# column whose name starts with the given letters ends up in the corresponding list.
p_features = [name for name in train_data.columns if name.startswith('p')]
v_features = [name for name in train_data.columns if name.startswith('v')]
dp_features = [name for name in train_data.columns if name.startswith('dp')]
dv_features = [name for name in train_data.columns if name.startswith('dv')]

# Concatenation order: p*, v*, dp*, dv*.
all_features = [*p_features, *v_features, *dp_features, *dv_features]

# Names of the two target columns.
targets = ['t0', 't1']

# Основной функционал

In [None]:
def feature_columns() -> List[str]:
    """Return the ordered list of model-input column names.

    The order is p0..p11, v0..v11, dp0..dp3, dv0..dv3 (32 names in total).
    """
    # (prefix, how many indexed columns) in the exact order they must appear.
    layout = (("p", 12), ("v", 12), ("dp", 4), ("dv", 4))
    return [f"{prefix}{idx}" for prefix, count in layout for idx in range(count)]


def target_columns() -> List[str]:
    """Return the names of the two target columns, in channel order."""
    names = ("t0", "t1")
    return list(names)


def df_to_numpy_arrays(
    df: pd.DataFrame,
    seq_col: str = "seq_ix",
    step_col: str = "step_in_seq",
    need_pred_col: str = "need_prediction",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Convert a long-format DataFrame into three dense, C-contiguous arrays.

    Rows are stably sorted by ``(seq_col, step_col)`` and then reshaped, so every
    sequence must contain the same number of rows ``T``. Step values themselves
    are not validated; only their sort order is used.

    Args:
        df: frame holding the columns from ``feature_columns()``,
            ``target_columns()``, plus ``seq_col``, ``step_col`` and
            ``need_pred_col``.
        seq_col: column identifying the sequence a row belongs to.
        step_col: column giving the position of a row inside its sequence.
        need_pred_col: column flagging rows that take part in the loss/score.

    Returns:
        ``(X, Y, mask)`` where
        X is ``(N_seq, T, F)`` float32,
        Y is ``(N_seq, T, n_targets)`` float32,
        mask is ``(N_seq, T)`` uint8 (0/1).

    Raises:
        ValueError: if ``df`` has no rows or the sequences differ in length.
    """
    feat_cols = feature_columns()
    targ_cols = target_columns()

    df_sorted = df.sort_values([seq_col, step_col], kind="stable")

    _, counts = np.unique(df_sorted[seq_col].to_numpy(), return_counts=True)
    if counts.size == 0:
        # Give a clear message instead of NumPy's "zero-size array to reduction" error.
        raise ValueError("Cannot build sequence arrays from an empty DataFrame.")
    if counts.min() != counts.max():
        raise ValueError("Ожидается что все последовательности одной длины (T).")
    n_seq = len(counts)
    seq_len = int(counts[0])  # plain int rather than a NumPy scalar

    X = np.ascontiguousarray(
        df_sorted[feat_cols].to_numpy(dtype=np.float32).reshape(n_seq, seq_len, len(feat_cols))
    )
    Y = np.ascontiguousarray(
        df_sorted[targ_cols].to_numpy(dtype=np.float32).reshape(n_seq, seq_len, len(targ_cols))
    )
    mask = np.ascontiguousarray(
        df_sorted[need_pred_col].to_numpy(dtype=np.uint8).reshape(n_seq, seq_len)
    )
    return X, Y, mask

In [None]:
class TimeSeriesDataset(Dataset):
    """Dataset that yields one whole sequence per item.

    Each item is a tuple ``(x, y, mask, w)``:
        x:    ``(T, F)`` features
        y:    ``(T, C)`` targets
        mask: ``(T,)`` 0/1 flags
        w:    ``(T, C)`` float32 weights ``max(|y|, eps)``, or ``None`` when
              ``precompute_weights`` is False

    Tensors are created with ``torch.from_numpy`` and therefore share memory
    with the stored arrays; x, y and mask keep the dtype of the arrays passed in.

    NOTE(review): when ``precompute_weights=False`` every item carries ``None``;
    confirm that the collate function used by the DataLoader accepts that.

    Args:
        X_np: array of shape ``(N, T, F)``.
        Y_np: array of shape ``(N, T, C)``.
        mask_np: array of shape ``(N, T)``.
        precompute_weights: build the per-point weights once at construction.
        eps: lower bound applied to the weights.

    Raises:
        ValueError: if an array has the wrong rank or the arrays disagree on
            ``(N, T)``.
    """

    def __init__(
        self,
        X_np: np.ndarray,
        Y_np: np.ndarray,
        mask_np: np.ndarray,
        precompute_weights: bool = True,
        eps: float = 1e-8
    ):
        # Explicit raises instead of asserts: asserts disappear under ``python -O``.
        if X_np.ndim != 3:
            raise ValueError(f"X_np must be 3-D (N, T, F), got shape {X_np.shape}")
        if Y_np.ndim != 3:
            raise ValueError(f"Y_np must be 3-D (N, T, C), got shape {Y_np.shape}")
        if mask_np.ndim != 2:
            raise ValueError(f"mask_np must be 2-D (N, T), got shape {mask_np.shape}")
        if not (X_np.shape[:2] == Y_np.shape[:2] == mask_np.shape):
            raise ValueError(
                "X_np, Y_np and mask_np must agree on (N, T); got "
                f"{X_np.shape[:2]}, {Y_np.shape[:2]} and {mask_np.shape}"
            )

        self.X = X_np
        self.Y = Y_np
        self.mask = mask_np
        self.eps = eps
        if precompute_weights:
            # Weights are |y| floored at eps, stored once as contiguous float32.
            floored = np.maximum(np.abs(self.Y), float(eps))
            self.weights = np.ascontiguousarray(floored.astype(np.float32))
        else:
            self.weights = None

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        x = torch.from_numpy(self.X[idx])  # (T, F)
        y = torch.from_numpy(self.Y[idx])  # (T, C)
        mask = torch.from_numpy(self.mask[idx])  # (T,)
        w = None if self.weights is None else torch.from_numpy(self.weights[idx])  # (T, C)
        return x, y, mask, w


def make_dataloader(
    dataset: Dataset,
    batch_size: int = 64,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    persistent_workers: bool = True,
    prefetch_factor: int = 2,
) -> DataLoader:
    """Build a ``DataLoader`` for ``dataset``.

    Suggested settings from the original author:
    - num_workers: 2..8 (depends on CPU and disk)
    - pin_memory=True when training on a GPU
    - persistent_workers=True to avoid re-spawning workers every epoch
    - prefetch_factor: 2..4

    ``persistent_workers`` and ``prefetch_factor`` are only forwarded when
    ``num_workers > 0``; DataLoader rejects them in single-process mode, so
    ``make_dataloader(ds, num_workers=0)`` previously raised ``ValueError``.

    NOTE(review): the original comment suspected that shuffling may be affected
    by persistent workers and suggested passing ``generator`` plus
    ``worker_init_fn``; this has not been verified here.
    """
    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    if num_workers > 0:
        # Worker-only options: invalid when data is loaded in the main process.
        loader_kwargs["persistent_workers"] = persistent_workers
        loader_kwargs["prefetch_factor"] = prefetch_factor
    return DataLoader(dataset, **loader_kwargs)

In [None]:
def batch_weighted_pearson_torch(y_true: torch.Tensor,
                                 y_pred: torch.Tensor,
                                 mask: torch.Tensor,
                                 weights: Optional[torch.Tensor] = None,
                                 clip_value: float = 6.0,
                                 eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
    """Weighted Pearson correlation over all valid points of a batch.

    All sequences of the batch are flattened into one long series, points with
    ``mask <= 0`` are dropped, and one correlation per channel is computed.

    Args:
        y_true: targets, shape ``(B, T, C)``.
        y_pred: predictions, shape ``(B, T, C)``; clamped to
            ``[-clip_value, clip_value]`` before use.
        mask: ``(B, T)``; entries > 0 mark points that are used.
        weights: optional ``(B, T, C)`` weights; when None, ``|y_true|`` is used.
        clip_value: clamp bound applied to the predictions.
        eps: floor for the weights, the weight sum and the denominator.

    Returns:
        ``(corr_per_channel, mean_corr)``: a ``(C,)`` tensor clamped to
        ``[-1, 1]`` and its plain mean. The mean is of the signed correlations;
        no absolute value is taken. When no point is valid both are zeros
        (and carry no gradient).
    """
    _, _, C = y_true.shape
    device = y_true.device
    dtype = y_true.dtype

    # Flatten batch and time into one axis and keep only the valid points.
    valid_mask = (mask.unsqueeze(-1) > 0).reshape(-1)  # (B*T,)
    if not valid_mask.any():
        corr = torch.zeros(C, device=device, dtype=dtype)
        return corr, torch.tensor(0.0, device=device, dtype=dtype)

    y = y_true.reshape(-1, C)[valid_mask]  # (N_valid, C)
    yhat = y_pred.reshape(-1, C)[valid_mask]

    # Clip predictions the same way the scorer does.
    yhat = torch.clamp(yhat, -clip_value, clip_value)

    if weights is None:
        w = torch.abs(y)
    else:
        w = weights.reshape(-1, C)[valid_mask]
    # Floor the weights so that no valid point has zero weight.
    w = torch.maximum(w, torch.tensor(eps, device=device, dtype=dtype))

    sum_w = w.sum(dim=0).clamp_min(eps)  # (C,)

    mean_y = (w * y).sum(dim=0) / sum_w
    mean_yhat = (w * yhat).sum(dim=0) / sum_w

    y_c = y - mean_y.unsqueeze(0)
    yhat_c = yhat - mean_yhat.unsqueeze(0)

    cov = (w * y_c * yhat_c).sum(dim=0) / sum_w
    var_y = (w * y_c * y_c).sum(dim=0) / sum_w
    var_yhat = (w * yhat_c * yhat_c).sum(dim=0) / sum_w

    # Clamp *inside* the square root. sqrt(max(v, eps^2)) equals max(sqrt(v), eps)
    # in the forward pass, but the derivative of sqrt at exactly 0 is infinite, so
    # clamping after the sqrt turned a zero variance (e.g. constant predictions)
    # into NaN gradients.
    denom = torch.sqrt((var_y * var_yhat).clamp_min(eps * eps))
    corr = cov / denom
    corr = torch.clamp(corr, -1.0, 1.0)
    # The signed mean is used on purpose; taking |corr| was judged unnecessary.
    mean_abs_corr = corr.mean()
    return corr, mean_abs_corr

# -------------------------
# 5) Combined loss: alpha * weighted MSE - beta * mean weighted Pearson correlation (signed, no absolute value), computed batch-wise with flattening inside batch
# -------------------------
class CombinedLoss(torch.nn.Module):
    """Loss combining a weighted MSE with a weighted Pearson-correlation term.

    ``loss = alpha * weighted_mse - beta * mean_corr`` where ``mean_corr`` is the
    mean over channels of the signed correlation returned by
    ``batch_weighted_pearson_torch``.
    """

    def __init__(self,
                 alpha: float = 1.0,
                 beta: float = 1.0,
                 clip_value: float = 6.0,
                 eps: float = 1e-8,
                 weight_gamma: float = 1.0,
                 max_weight: Optional[float] = None):
        """
        alpha: multiplier of the MSE term
        beta: multiplier of the correlation term
        clip_value: prediction clamp forwarded to the correlation helper
        eps: floor for weights and weight sums
        weight_gamma: exponent of the weight transform, w = |y|**gamma
        max_weight: when not None, weights are capped at this value
        """
        super().__init__()
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.clip_value = float(clip_value)
        self.eps = float(eps)
        self.gamma = float(weight_gamma)
        self.max_weight = None if max_weight is None else float(max_weight)

    def forward(self,
                y_pred: torch.Tensor,  # (B,T,C)
                y_true: torch.Tensor,  # (B,T,C)
                mask: torch.Tensor,  # (B,T)
                weights: Optional[torch.Tensor] = None  # optional (B,T,C)
                ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Return ``(loss, info)``; ``info`` holds detached Python numbers for logging."""
        device, dtype = y_true.device, y_true.dtype
        _b, _t, _c = y_true.shape  # y_true must be 3-D

        # ----- Point weights -----
        # External weights are used when given (raised to gamma unless gamma == 1);
        # otherwise the weights are |y_true| ** gamma.
        if weights is not None:
            w = weights.clone()
            if self.gamma != 1.0:
                w = w.pow(self.gamma)
        else:
            w = torch.abs(y_true).pow(self.gamma)  # (B,T,C)

        # Floor at eps, then optionally cap at max_weight.
        floor = torch.tensor(self.eps, device=device, dtype=dtype)
        w = torch.maximum(w, floor)
        if self.max_weight is not None:
            cap = torch.tensor(self.max_weight, device=device, dtype=dtype)
            w = torch.minimum(w, cap)

        # Zero the weights of masked-out points.
        w = w * mask.unsqueeze(-1)  # (B,T,C)

        # ----- Weighted MSE, normalised per channel over batch and time -----
        sum_w = w.sum(dim=(0, 1)).clamp_min(self.eps)  # (C,)
        sq_err = (y_pred - y_true).pow(2) * w  # (B,T,C)
        mse_per_ch = sq_err.sum(dim=(0, 1)) / sum_w  # (C,)
        mse_mean = mse_per_ch.mean()

        # ----- Pearson term (batch flattened), using the same masked weights -----
        corr_per_ch, mean_abs_corr = batch_weighted_pearson_torch(
            y_true, y_pred, mask, weights=w, clip_value=self.clip_value, eps=self.eps
        )

        term_mse = mse_mean
        term_corr = - mean_abs_corr
        loss = self.alpha * term_mse + self.beta * term_corr

        def _to_float(t: torch.Tensor) -> float:
            return float(t.detach().cpu())

        info = {
            "loss": _to_float(loss),  # alpha * mse_mean - beta * mean correlation
            "mse_mean": _to_float(mse_mean),  # unweighted mean of the per-channel MSEs
            "mse_per_ch": mse_per_ch.detach().cpu().tolist(),  # weighted MSE per target
            "corr_per_ch": corr_per_ch.detach().cpu().tolist(),  # weighted correlation per target
            "mean_abs_corr": _to_float(mean_abs_corr),  # mean of the per-target correlations
            "term_mse": _to_float(term_mse),  # same value as mse_mean
            "term_corr": _to_float(term_corr),  # negated mean correlation
        }
        return loss, info


# -------------------------
# 6) Numpy scorer (exact copy of user's numpy scorer) for eval
# -------------------------
def weighted_pearson_correlation_np(y_true: np.ndarray, y_pred: np.ndarray, clip_value: float = 6.0, eps: float = 1e-8) -> float:
    """Weighted Pearson correlation between ``y_true`` and clipped ``y_pred``.

    Weights are ``max(|y_true|, eps)``; predictions are clipped to
    ``[-clip_value, clip_value]``. Returns 0.0 when the weight sum is zero or
    when either weighted variance is not positive.
    """
    pred = np.clip(y_pred, -clip_value, clip_value)
    w = np.maximum(np.abs(y_true), eps)

    total_w = np.sum(w)
    if total_w == 0:
        return 0.0

    # Deviations from the weighted means.
    dev_true = y_true - np.sum(y_true * w) / total_w
    dev_pred = pred - np.sum(pred * w) / total_w

    cov = np.sum(w * dev_true * dev_pred) / total_w
    var_true = np.sum(w * dev_true**2) / total_w
    var_pred = np.sum(w * dev_pred**2) / total_w

    if var_true <= 0 or var_pred <= 0:
        return 0.0

    return float(cov / (np.sqrt(var_true) * np.sqrt(var_pred)))


def calc_global_np_score(preds_all: np.ndarray, targets_all: np.ndarray, clip_value: float = 6.0) -> Dict[str, float]:
    """Score every target column and average the results.

    preds_all, targets_all: arrays of shape (N_total, C).
    Returns a dict with one weighted Pearson value per column under the keys
    "t1", "t2", ... (1-based) and their mean under "weighted_pearson".
    """
    n_channels = preds_all.shape[1]
    scores = {
        f"t{ch+1}": weighted_pearson_correlation_np(
            targets_all[:, ch], preds_all[:, ch], clip_value=clip_value
        )
        for ch in range(n_channels)
    }
    # The mean is taken before the summary key is added, so it covers only the per-column scores.
    scores["weighted_pearson"] = float(np.mean(list(scores.values())))
    return scores

In [None]:
def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: CombinedLoss,
    device: torch.device,
    scaler: Optional[torch.cuda.amp.GradScaler] = None,
    grad_clip: Optional[float] = 1.0,
    accum_steps: int = 1,
) -> Dict[str, Any]:
    """Run one training pass over ``dataloader`` and return averaged statistics.

    Args:
        model: called as ``model(X)``; must return a 2-tuple whose first element
            is the prediction tensor (the second element is ignored).
        dataloader: yields ``(X, Y, mask, weights)`` batches; ``weights`` may be None.
        optimizer: optimizer stepped every ``accum_steps`` batches.
        loss_fn: called as ``loss_fn(preds, Y, mask, weights)`` and expected to
            return ``(loss, info)`` with the keys ``mse_mean``, ``mean_abs_corr``
            and ``corr_per_ch`` in ``info``.
        device: device the batch tensors are moved to.
        scaler: when given, autocast is enabled and the scaler drives
            backward/step; when None, training runs without mixed precision.
        grad_clip: max gradient norm, or None to disable clipping.
        accum_steps: number of batches whose gradients are accumulated per
            optimizer step.

    Returns:
        dict with ``train_loss``, ``train_mse_mean``, ``train_mean_abs_corr``,
        ``train_mean_corr_by_ch`` (all averaged over batches) and ``batches``.

    Raises:
        ValueError: if ``accum_steps`` is smaller than 1.
    """
    if accum_steps < 1:
        raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")

    def _apply_optimizer_step() -> None:
        """Clip (optionally), step the optimizer and clear the gradients."""
        if scaler is None:
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
        else:
            if grad_clip is not None:
                # Gradients must be unscaled before their norm is measured.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        optimizer.zero_grad()

    model.train()
    total_loss = 0.0
    batches = 0
    # accumulators for logging
    mse_means = []
    corr_means = []
    corr_by_ch_means = []
    pending = 0  # batches back-propagated since the last optimizer step

    optimizer.zero_grad()
    for batch in tqdm(dataloader, desc="train", leave=False):
        X, Y, mask, weights = batch  # X:(B,T,F), Y:(B,T,2), mask:(B,T), weights:(B,T,2) or None
        X = X.to(device, non_blocking=True)
        Y = Y.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)
        if weights is not None:
            weights = weights.to(device, non_blocking=True)

        # Mixed precision is active only when a GradScaler was supplied.
        with torch.cuda.amp.autocast(enabled=(scaler is not None)):
            preds, _ = model(X)
            loss, info = loss_fn(preds, Y, mask, weights)

        # Divide so that the accumulated gradient is the average over the window.
        loss_value = loss / accum_steps

        if scaler is None:
            loss_value.backward()
        else:
            scaler.scale(loss_value).backward()

        pending += 1
        if pending == accum_steps:
            _apply_optimizer_step()
            pending = 0

        total_loss += float(loss.detach().cpu())
        mse_means.append(info["mse_mean"])
        corr_means.append(info["mean_abs_corr"])
        corr_by_ch_means.append(info["corr_per_ch"])
        batches += 1

    # Flush a trailing partial window. Without this, the gradients of the last
    # ``batches % accum_steps`` batches were never applied and were wiped by the
    # zero_grad() at the start of the next epoch.
    if pending:
        _apply_optimizer_step()

    stats = {
        "train_loss": total_loss / max(1, batches),  # mean loss over batches
        "train_mse_mean": float(np.mean(mse_means)) if len(mse_means) else 0.0,  # mean MSE term over batches
        "train_mean_abs_corr": float(np.mean(corr_means)) if len(corr_means) else 0.0,  # mean correlation over batches
        'train_mean_corr_by_ch': np.mean(corr_by_ch_means, axis=0).tolist() if len(corr_by_ch_means) else [0.0, 0.0],
        "batches": batches
    }
    return stats


def eval_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: CombinedLoss,
    device: torch.device,
    clip_value: float = 6.0,
) -> Dict[str, Any]:
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Two kinds of numbers are returned in one dict:
    - batch-averaged diagnostics: ``val_loss``, ``val_mse_mean``,
      ``val_mean_abs_corr_batchwise``;
    - global scores computed once over all valid (mask > 0) points of the whole
      loader: ``t1``, ``t2``, ... (weighted Pearson per target),
      ``weighted_pearson`` (their mean) and ``global_mse`` (|y|-weighted MSE
      averaged over targets).

    When nothing was predicted or no point is valid, the global part is
    ``{"t1": 0.0, "t2": 0.0, "weighted_pearson": 0.0}`` without ``global_mse``.
    """
    model.eval()
    loss_sum = 0.0
    n_batches = 0
    mse_per_batch = []
    corr_per_batch = []
    pred_chunks = []
    target_chunks = []
    mask_chunks = []

    with torch.no_grad():
        for X, Y, mask, weights in tqdm(dataloader, desc="eval", leave=False):
            X = X.to(device, non_blocking=True)
            Y = Y.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)
            if weights is not None:
                weights = weights.to(device, non_blocking=True)

            preds, _ = model(X)
            loss, info = loss_fn(preds, Y, mask, weights)

            loss_sum += float(loss.detach().cpu())
            n_batches += 1

            mse_per_batch.append(info.get("mse_mean", 0.0))
            corr_per_batch.append(info.get("mean_abs_corr", 0.0))

            # Keep raw outputs on the CPU for the global score below.
            pred_chunks.append(preds.detach().cpu().numpy())
            target_chunks.append(Y.detach().cpu().numpy())
            mask_chunks.append(mask.detach().cpu().numpy())

    # Batch-averaged diagnostics.
    stats = {
        "val_loss": loss_sum / max(1, n_batches),  # mean loss over batches
        "val_mse_mean": float(np.mean(mse_per_batch)) if mse_per_batch else 0.0,  # batchwise average MSE
        "val_mean_abs_corr_batchwise": float(np.mean(corr_per_batch)) if corr_per_batch else 0.0,  # batchwise average corr
    }

    preds_all = np.concatenate(pred_chunks, axis=0) if pred_chunks else np.empty((0,0,0))
    targets_all = np.concatenate(target_chunks, axis=0) if target_chunks else np.empty((0,0,0))
    masks_all = np.concatenate(mask_chunks, axis=0) if mask_chunks else np.empty((0,0))

    # Nothing predicted at all: return neutral scores.
    if preds_all.size == 0:
        stats.update({"t1": 0.0, "t2": 0.0, "weighted_pearson": 0.0})
        return stats

    # Flatten batch and time, then keep only the points flagged by the mask.
    _, _, n_ch = preds_all.shape
    keep = masks_all.reshape(-1) > 0
    if not keep.any():
        stats.update({"t1": 0.0, "t2": 0.0, "weighted_pearson": 0.0})
        return stats

    preds_valid = preds_all.reshape(-1, n_ch)[keep]
    targets_valid = targets_all.reshape(-1, n_ch)[keep]

    # Competition metric: weighted Pearson per target, then averaged.
    per_target_corrs = [
        weighted_pearson_correlation_np(targets_valid[:, ch], preds_valid[:, ch], clip_value=clip_value)
        for ch in range(n_ch)
    ]
    weighted_pearson = float(np.mean(per_target_corrs))

    # Global weighted MSE with the same weights as the correlation (w = |y|, floored).
    point_w = np.maximum(np.abs(targets_valid), 1e-8)
    w_totals = point_w.sum(axis=0)
    mse_per_ch = []
    for ch in range(n_ch):
        w_ch = point_w[:, ch]
        denom = w_totals[ch] if w_totals[ch] > 0 else 1.0  # guard against division by zero
        mse_per_ch.append(np.sum(w_ch * (preds_valid[:, ch] - targets_valid[:, ch])**2) / denom)
    global_mse = float(np.mean(mse_per_ch))

    stats.update({f"t{ch+1}": per_target_corrs[ch] for ch in range(n_ch)})
    stats["weighted_pearson"] = weighted_pearson  # competition metric
    stats["global_mse"] = global_mse  # MSE over the whole validation set
    return stats


# -------------------------
# 8) fit function
# -------------------------
# def fit(
#     model: nn.Module,
#     train_loader: DataLoader,
#     val_loader: DataLoader,
#     device: torch.device,
#     num_epochs: int = 30,
#     lr: float = 1e-3,
#     weight_decay: float = 1e-5,
#     alpha: float = 1.0,
#     beta: float = 1.0,
#     weight_gamma: float = 1.0,
#     grad_clip: float = 1.0,
#     accum_steps: int = 1,
#     use_amp: bool = True,
#     save_path: str = "best_model.pt",
#     monitor: str = "weighted_pearson",  # metric from eval_one_epoch to maximize
#     maximize_monitor: bool = True,
#     log_dir: str = "runs/experiment"
# ):
#     model = model.to(device)
#     writer = SummaryWriter(log_dir=log_dir)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
#     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max" if maximize_monitor else "min",
#                                                             factor=0.5, patience=3, verbose=True)
#     loss_fn = CombinedLoss(alpha=alpha, beta=beta, weight_gamma=weight_gamma)
#     scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

#     best_metric = None
#     history = []

#     for epoch in range(1, num_epochs + 1):
#         train_stats = train_one_epoch(model, train_loader, optimizer, loss_fn, device, scaler if use_amp else None, grad_clip=grad_clip, accum_steps=accum_steps)
#         val_stats = eval_one_epoch(model, val_loader, loss_fn, device)
#         print(val_stats)

#         cur_metric = val_stats.get(monitor, val_stats.get("weighted_pearson", 0.0))
#         # scheduler step
#         scheduler.step(cur_metric)

#         writer.add_scalar('Loss/train', train_stats['train_loss'], epoch)
#         writer.add_scalar('Loss/val', val_stats['val_loss'], epoch)
#         writer.add_scalar('Corr/train', train_stats['train_mean_abs_corr'], epoch)
#         writer.add_scalar('Corr/val', val_stats.get('weighted_pearson', 0.0), epoch)

#         improved = (best_metric is None) or (maximize_monitor and cur_metric > best_metric) or (not maximize_monitor and cur_metric < best_metric)
#         if improved:
#             best_metric = cur_metric
#             torch.save({
#                 "epoch": epoch,
#                 "model_state_dict": model.state_dict(),
#                 "optimizer_state_dict": optimizer.state_dict(),
#                 "best_metric": best_metric,
#             }, save_path)
#             print(f"Saved new best model at epoch {epoch}: {best_metric:.6f}")

#         # log
#         row = {"epoch": epoch}
#         row.update(train_stats)
#         row.update(val_stats)
#         history.append(row)

#         print(f"Epoch {epoch:03d} | train_loss={train_stats['train_loss']:.6f} train_corr={train_stats['train_mean_abs_corr']:.6f} train_corr_by_ch={train_stats['train_mean_corr_by_ch']}"
#               f"| val_loss={val_stats['val_loss']:.6f} val_weighted_pearson={val_stats.get('weighted_pearson', 0.0):.6f} val_corr_by_ch={val_stats.get('t1', 0.0), val_stats.get('t2', 0.0)}")

#     writer.close()

#     return model, history


def fit(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    num_epochs: int = 30,
    lr: float = 1e-3,
    weight_decay: float = 1e-5,
    alpha: float = 1.0,
    beta: float = 1.0,
    weight_gamma: float = 1.0,
    grad_clip: float = 1.0,
    accum_steps: int = 1,
    use_amp: bool = True,
    save_path: str = "best_model.pt",
    monitor: str = "weighted_pearson",  # metric from eval_one_epoch to maximize
    maximize_monitor: bool = True,
    patience: int = 5  # early stopping patience (epochs without improvement)
):
    """Train ``model`` with early stopping and return the best checkpoint's weights.

    Each epoch runs ``train_one_epoch`` and ``eval_one_epoch``, feeds the
    monitored validation metric to a ``ReduceLROnPlateau`` scheduler, saves a
    checkpoint to ``save_path`` whenever the metric improves, and stops after
    ``patience`` consecutive epochs without improvement.

    Args:
        model: network to train; moved to ``device``.
        train_loader, val_loader: loaders for training and validation.
        device: target device.
        num_epochs: maximum number of epochs.
        lr, weight_decay: AdamW hyper-parameters.
        alpha, beta, weight_gamma: forwarded to ``CombinedLoss``.
        grad_clip, accum_steps: forwarded to ``train_one_epoch``.
        use_amp: when True a ``GradScaler`` is created and passed to training.
        save_path: file the best checkpoint is written to.
        monitor: key of the validation stats used for model selection; falls
            back to ``"weighted_pearson"`` when the key is missing.
        maximize_monitor: True if larger values of ``monitor`` are better.
        patience: epochs without improvement before stopping.

    Returns:
        ``(model, history)`` where ``history`` is a list with one dict per epoch
        merging the epoch number, training stats and validation stats.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # NOTE(review): `verbose` is reported as deprecated in recent PyTorch releases;
    # confirm against the pinned version before upgrading.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max" if maximize_monitor else "min",
                                                            factor=0.5, patience=3, verbose=True)
    loss_fn = CombinedLoss(alpha=alpha, beta=beta, weight_gamma=weight_gamma)
    # The scaler is only needed (and only passed on) when AMP is requested.
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    best_metric = None
    history = []
    epochs_no_improve = 0  # counter for early stopping

    for epoch in range(1, num_epochs + 1):
        train_stats = train_one_epoch(
            model,
            train_loader,
            optimizer,
            loss_fn,
            device,
            scaler,
            grad_clip=grad_clip,
            accum_steps=accum_steps,
        )

        val_stats = eval_one_epoch(model, val_loader, loss_fn, device)

        cur_metric = val_stats.get(monitor, val_stats.get("weighted_pearson", 0.0))

        # ReduceLROnPlateau always takes the metric. Calling it directly lets a
        # genuine scheduler failure surface instead of being silently swallowed.
        scheduler.step(cur_metric)

        # check improvement
        if best_metric is None:
            improved = True
        elif maximize_monitor:
            improved = cur_metric > best_metric
        else:
            improved = cur_metric < best_metric

        if improved:
            best_metric = cur_metric
            epochs_no_improve = 0  # reset counter on improvement
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "best_metric": best_metric,
                },
                save_path,
            )
            print(f"Saved new best model at epoch {epoch}: {best_metric:.6f}")
        else:
            epochs_no_improve += 1
            print(
                f"No improvement for {epochs_no_improve}/{patience} epochs "
                f"(cur {monitor}={cur_metric:.6f}, best={best_metric:.6f})"
            )

        # log
        row = {"epoch": epoch}
        row.update(train_stats)
        row.update(val_stats)
        history.append(row)

        print(f"Epoch {epoch:03d} | train_loss={train_stats['train_loss']:.6f} train_corr={train_stats['train_mean_abs_corr']:.6f} train_corr_by_ch={train_stats['train_mean_corr_by_ch']}"
              f"| val_loss={val_stats['val_loss']:.6f} val_weighted_pearson={val_stats.get('weighted_pearson', 0.0):.6f} val_corr_by_ch={val_stats.get('t1', 0.0), val_stats.get('t2', 0.0)}")

        # Early stop: break fold training and return (so cross_validate proceeds to next fold)
        if epochs_no_improve >= patience:
            print(
                f"Early stopping triggered (no improvement for {patience} epochs). "
                "Stopping training for this fold."
            )
            break

    # Restore the best weights, but only if this run actually wrote a checkpoint.
    # Otherwise a stale file left at save_path by an earlier run could be loaded.
    if best_metric is not None:
        try:
            if os.path.exists(save_path):
                ckpt = torch.load(save_path, map_location=device)
                model.load_state_dict(ckpt["model_state_dict"])
                print(
                    f"Loaded best model from {save_path} "
                    f"(epoch {ckpt.get('epoch', '?')}, metric={ckpt.get('best_metric')})"
                )
        except Exception as e:
            # Best effort: keep the last-epoch weights if the checkpoint cannot be read.
            print("Warning: could not load best checkpoint:", e)

    return model, history

In [None]:
import random
import os
def seed_worker(worker_id):
    """Seed NumPy and ``random`` from torch's initial seed, reduced modulo 2**32.

    ``worker_id`` is accepted but not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

def seed_everything(seed=42):
    """Seed ``random``, NumPy and torch with ``seed`` and make cuDNN deterministic.

    Also writes ``seed`` to the ``PYTHONHASHSEED`` environment variable and
    turns off cuDNN benchmarking.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def cross_validate(
    X: np.ndarray,
    Y: np.ndarray,
    mask: np.ndarray,
    n_splits: int = 5,
    seed: int = 42,
    model_ctor_kwargs: Optional[dict] = None,
    fit_kwargs: Optional[dict] = None,
    dataloader_kwargs: Optional[dict] = None,
    out_dir: str = "cv_results",
    trial: "optuna.trial.Trial" = None,  # optional for pruning
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """
    Train and evaluate one ``BaselineGRU`` per fold of a shuffled K-fold split.

    The split is taken over the first axis of ``X`` / ``Y`` / ``mask``
    (whole sequences). For every fold the RNGs are re-seeded with
    ``seed + fold``, the model is trained with ``fit``, the checkpoint saved
    at ``<out_dir>/best_model_fold<fold>.pt`` is reloaded when it exists and
    the validation split is scored once more with ``eval_one_epoch``.

    Args:
        X, Y, mask: numpy arrays shaped (N_seq, T, F), (N_seq, T, C), (N_seq, T)
        n_splits: number of folds
        seed: base seed; the K-fold shuffle uses ``seed - 16`` and fold ``k``
            uses ``seed + k``
        model_ctor_kwargs: kwargs for the ``BaselineGRU`` constructor;
            defaults to input_size, hidden_size, output_size, dropout
        fit_kwargs: kwargs forwarded to fit(...) (num_epochs, lr, etc.);
            ``save_path`` is always overridden per fold
        dataloader_kwargs: kwargs forwarded to DataLoader
            (batch_size, num_workers, pin_memory); the validation loader
            uses twice the batch size
        out_dir: directory that receives the per-fold checkpoints
        trial: optional Optuna trial; when given, the fold metric is reported
            after each fold and ``optuna.exceptions.TrialPruned`` is raised
            if the pruner asks to stop

    Returns:
        (fold_results, summary): a list with one dict per fold
        (fold, train_size, val_size, val_stats, history, model_path) and a
        dict with mean/std of ``weighted_pearson`` and ``global_mse``.
    """

    os.makedirs(out_dir, exist_ok=True)

    # Fill in built-in defaults for every configuration dict not supplied.
    if model_ctor_kwargs is None:
        model_ctor_kwargs = dict(
            input_size=X.shape[2],
            hidden_size=256,
            output_size=Y.shape[2],
            dropout=0.1,
        )
    if fit_kwargs is None:
        fit_kwargs = dict(
            num_epochs=20,
            lr=1e-3,
            weight_decay=1e-5,
            alpha=1.0,
            beta=3.0,
            grad_clip=1.0,
            accum_steps=1,
            use_amp=True,
            save_path=None,  # replaced with a per-fold path below
            monitor="weighted_pearson",
            maximize_monitor=True,
        )
    if dataloader_kwargs is None:
        dataloader_kwargs = dict(batch_size=16, num_workers=4, pin_memory=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    splitter = KFold(n_splits=n_splits, shuffle=True, random_state=seed - 16)
    sequence_ids = np.arange(X.shape[0])

    fold_results: List[Dict[str, Any]] = []

    for fold, (train_idx, val_idx) in enumerate(
        splitter.split(sequence_ids), start=1
    ):
        print(f"\n=== Fold {fold}/{n_splits} ===")

        fold_seed = seed + fold
        seed_everything(fold_seed)
        # Dedicated generator so the DataLoader shuffling is reproducible.
        shuffle_rng = torch.Generator()
        shuffle_rng.manual_seed(fold_seed)

        ds_train = TimeSeriesDataset(
            X[train_idx], Y[train_idx], mask[train_idx], precompute_weights=True
        )
        ds_val = TimeSeriesDataset(
            X[val_idx], Y[val_idx], mask[val_idx], precompute_weights=True
        )

        train_loader_opts = dict(
            dataloader_kwargs, worker_init_fn=seed_worker, generator=shuffle_rng
        )
        train_loader = DataLoader(ds_train, shuffle=True, **train_loader_opts)

        # Evaluation runs with a doubled batch size.
        val_loader_opts = dict(
            dataloader_kwargs,
            batch_size=dataloader_kwargs.get("batch_size", 16) * 2,
        )
        val_loader = DataLoader(ds_val, shuffle=False, **val_loader_opts)

        # Fresh model for every fold.
        model = BaselineGRU(**model_ctor_kwargs)

        fold_save = os.path.join(out_dir, f"best_model_fold{fold}.pt")
        fit_kwargs_local = {**fit_kwargs, "save_path": fold_save}

        model, history = fit(
            model, train_loader, val_loader, device, **fit_kwargs_local
        )

        # Restore the checkpoint written for this fold, if there is one.
        if not os.path.exists(fold_save):
            print(f"Warning: no checkpoint saved for fold {fold}, "
                  "using final model")
        else:
            ckpt = torch.load(fold_save, map_location=device)
            model.load_state_dict(ckpt["model_state_dict"])
            print(f"Loaded best checkpoint for fold {fold} "
                  f"(epoch {ckpt.get('epoch', '?')})")

        # Score the validation split with the (re)loaded weights.
        loss_fn = CombinedLoss(
            alpha=fit_kwargs_local.get("alpha", 1.0),
            beta=fit_kwargs_local.get("beta", 1.0),
            weight_gamma=fit_kwargs_local.get("weight_gamma", 1.0),
        )
        val_stats = eval_one_epoch(model, val_loader, loss_fn, device)

        # Report the monitored metric to the Optuna pruner, if any.
        monitor_metric = fit_kwargs_local.get("monitor", "weighted_pearson")
        fallback_metric = val_stats.get("weighted_pearson", 0.0)
        fold_metric = float(val_stats.get(monitor_metric, fallback_metric))
        if trial is not None:
            trial.report(fold_metric, fold)
            if trial.should_prune():
                trial.set_user_attr("pruned_after_fold", fold)
                raise optuna.exceptions.TrialPruned()

        fold_results.append(
            {
                "fold": fold,
                "train_size": len(train_idx),
                "val_size": len(val_idx),
                "val_stats": val_stats,
                "history": history,
                "model_path": fold_save,
            }
        )

    # Aggregate the per-fold validation metrics.
    pears = np.array(
        [fr["val_stats"].get("weighted_pearson", 0.0) for fr in fold_results]
    )
    mses = np.array(
        [fr["val_stats"].get("global_mse", 0.0) for fr in fold_results]
    )

    summary = {
        "n_folds": n_splits,
        "pearson_mean": float(np.mean(pears)),
        "pearson_std": float(np.std(pears, ddof=0)),
        "global_mse_mean": float(np.mean(mses)),
        "global_mse_std": float(np.std(mses, ddof=0)),
    }

    print("\n=== CV Summary ===")
    print(f"Weighted Pearson per-fold: {pears.tolist()}")
    print(f"Mean ± Std = {summary['pearson_mean']:.6f} ± {summary['pearson_std']:.6f}")
    print(f"Global MSE per-fold: {mses.tolist()}")
    print(f"Mean ± Std = {summary['global_mse_mean']:.6f} ± {summary['global_mse_std']:.6f}")

    return fold_results, summary



# def cross_validate(
#     X: np.ndarray,
#     Y: np.ndarray,
#     mask: np.ndarray,
#     n_splits: int = 5,
#     seed: int = 42,
#     # models_random_states: List = [],
#     model_ctor_kwargs: Optional[dict] = None,
#     fit_kwargs: Optional[dict] = None,
#     dataloader_kwargs: Optional[dict] = None,
#     out_dir: str = "cv_results",
# ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
#     """
#     Run KFold cross-validation over sequences.

#     Args:
#         X, Y, mask : numpy arrays shaped (N_seq, T, F), (N_seq, T, C), (N_seq, T)
#         n_splits: number of folds
#         seed: random seed for reproducibility
#         model_ctor_kwargs: kwargs for GRU constructor:
#             input_size, hidden_size, output_size, dropout
#         fit_kwargs: kwargs that will be forwarded to fit(...)
#             (num_epochs, lr, etc.)
#         dataloader_kwargs: kwargs forwarded to DataLoader
#             (batch_size, num_workers, pin_memory)
#         out_dir: directory to save fold models and logs

#     Returns:
#         (fold_results, summary) where fold_results is list of per-fold dicts,
#         and summary contains mean/std for main metrics.
#     """

#     os.makedirs(out_dir, exist_ok=True)

#     # if models_random_states == []:
#     #     self.seeds = [42] * n_splits
#     # elif len(models_random_states) != n_splits:
#     #     raise ValueError("Длина models_random_states должна совпадать с n_splits")

#     # defaults
#     if model_ctor_kwargs is None:
#         model_ctor_kwargs = {
#             "input_size": X.shape[2],
#             "hidden_size": 256,
#             "output_size": Y.shape[2],
#             "dropout": 0.1,
#         }
#     if fit_kwargs is None:
#         fit_kwargs = {
#             "num_epochs": 20,
#             "lr": 1e-3,
#             "weight_decay": 1e-5,
#             "alpha": 1.0,
#             "beta": 3.0,
#             "grad_clip": 1.0,
#             "accum_steps": 1,
#             "use_amp": True,
#             "save_path": None,          # will be set per-fold
#             "monitor": "weighted_pearson",
#             "maximize_monitor": True,
#         }
#     if dataloader_kwargs is None:
#         dataloader_kwargs = {
#             "batch_size": 16,
#             "num_workers": 4,
#             "pin_memory": True,
#         }

#     # Prepare KFold
#     kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed-16)
#     n_seq = X.shape[0]
#     indices = np.arange(n_seq)

#     fold_results: List[Dict[str, Any]] = []

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     # For reproducibility: set seeds
#     # np.random.seed(seed)
#     # torch.manual_seed(seed)
#     # if torch.cuda.is_available():
#     #     torch.cuda.manual_seed_all(seed)

#     for fold, (train_idx, val_idx) in enumerate(kf.split(indices), start=1):
#     # for fold, vals in enumerate(zip(kf.split(train_data), models_random_states), start=1):
#         print(f"\n=== Fold {fold}/{n_splits} ===")
#         # train_idx, val_idx = vals[0][0], vals[0][1]
#         # rs = vals[1]
#         seed_everything(seed+fold)
#         # Создаём генератор для воспроизводимого перемешивания в DataLoader
#         generator = torch.Generator()
#         generator.manual_seed(seed + fold)
#         # prepare datasets/loaders
#         X_train, Y_train, mask_train = X[train_idx], Y[train_idx], mask[train_idx]
#         X_val, Y_val, mask_val = X[val_idx], Y[val_idx], mask[val_idx]

#         ds_train = TimeSeriesDataset(X_train, Y_train, mask_train,
#                                      precompute_weights=True)
#         ds_val = TimeSeriesDataset(X_val, Y_val, mask_val,
#                                    precompute_weights=True)

#         train_loader = DataLoader(ds_train, shuffle=True, **{**dataloader_kwargs, "worker_init_fn": seed_worker, "generator": generator})
#         # For eval we can use larger batch size if memory allows
#         val_batch_size = dataloader_kwargs.get("batch_size", 16) * 2
#         val_loader = DataLoader(
#             ds_val,
#             shuffle=False,
#             **{**dataloader_kwargs, "batch_size": val_batch_size}
#         )

#         # build model fresh for each fold
#         model = BaselineGRU(**model_ctor_kwargs)

#         # unique save path per fold
#         fold_save = os.path.join(out_dir, f"best_model_fold{fold}.pt")
#         fit_kwargs_local = dict(fit_kwargs)
#         fit_kwargs_local["save_path"] = fold_save

#         # Train
#         model, history = fit(
#             model,
#             train_loader,
#             val_loader,
#             device,
#             **fit_kwargs_local
#         )

#         # load best checkpoint (fit saved best to fold_save)
#         if os.path.exists(fold_save):
#             ckpt = torch.load(fold_save, map_location=device)
#             model.load_state_dict(ckpt["model_state_dict"])
#             print(f"Loaded best checkpoint for fold {fold} "
#                   f"(epoch {ckpt.get('epoch', '?')})")
#         else:
#             print(f"Warning: no checkpoint saved for fold {fold}, "
#                   "using final model")

#         # final evaluation on validation (global metric computation)
#         loss_fn = CombinedLoss(
#             alpha=fit_kwargs_local.get("alpha", 1.0),
#             beta=fit_kwargs_local.get("beta", 1.0),
#             weight_gamma=fit_kwargs_local.get("weight_gamma", 1.0)
#         )
#         val_stats = eval_one_epoch(model, val_loader, loss_fn, device)
#         # eval_one_epoch returns dict with keys including
#         # "weighted_pearson" and "global_mse" (see previously)

#         # keep fold info
#         fold_info = {
#             "fold": fold,
#             "train_size": len(train_idx),
#             "val_size": len(val_idx),
#             "val_stats": val_stats,
#             "history": history,          # may be large; drop if not needed
#             "model_path": fold_save,
#         }
#         fold_results.append(fold_info)

#     # Aggregate results
#     # Extract weighted_pearson and global_mse per fold
#     pears = []
#     mses = []
#     for fr in fold_results:
#         vs = fr["val_stats"]
#         pears.append(vs.get("weighted_pearson", 0.0))
#         mses.append(vs.get("global_mse", 0.0))

#     pears = np.array(pears)
#     mses = np.array(mses)

#     summary = {
#         "n_folds": n_splits,
#         "pearson_mean": float(np.mean(pears)),
#         "pearson_std": float(np.std(pears, ddof=0)),
#         "global_mse_mean": float(np.mean(mses)),
#         "global_mse_std": float(np.std(mses, ddof=0)),
#     }

#     print("\n=== CV Summary ===")
#     print(f"Weighted Pearson per-fold: {pears.tolist()}")
#     print(f"Mean ± Std = {summary['pearson_mean']:.6f} ± "
#           f"{summary['pearson_std']:.6f}")
#     print(f"Global MSE per-fold: {mses.tolist()}")
#     print(f"Mean ± Std = {summary['global_mse_mean']:.6f} ± "
#           f"{summary['global_mse_std']:.6f}")

#     return fold_results, summary

In [None]:
def ensemble_inference(
    model_paths: List[str],
    dataloader: DataLoader,
    device: torch.device,
    model_ctor,
    model_kwargs: dict,
) -> np.ndarray:
    """
    Run inference with several checkpoints and average their predictions.

    Args:
        model_paths: paths to checkpoints; each must hold the weights under
            the ``"model_state_dict"`` key
        dataloader: DataLoader to run inference on; the first element of
            every batch is used as the model input
        device: torch.device the models and inputs are moved to
        model_ctor: callable building the model (e.g. ``BaselineGRU``); its
            forward pass must return a 2-tuple whose first item is the
            predictions
        model_kwargs: keyword arguments for ``model_ctor``

    Returns:
        preds_mean: element-wise mean over models of the per-model
        predictions, each concatenated over batches along axis 0 —
        (N_seq, T, C) when the model emits (B, T, C).
    """

    def _predict_with_checkpoint(path):
        # Build a fresh model, load the stored weights and predict everything.
        print(f"Loading model: {path}")

        net = model_ctor(**model_kwargs)
        checkpoint = torch.load(path, map_location=device)
        net.load_state_dict(checkpoint["model_state_dict"])
        net.to(device)
        net.eval()

        batch_outputs = []
        with torch.no_grad():
            for batch in dataloader:
                inputs = batch[0].to(device, non_blocking=True)
                preds, _ = net(inputs)
                batch_outputs.append(preds.cpu().numpy())

        return np.concatenate(batch_outputs, axis=0)

    per_model = [_predict_with_checkpoint(path) for path in model_paths]

    # (n_models, N_seq, ...) -> mean over the model axis
    stacked = np.stack(per_model, axis=0)
    return np.mean(stacked, axis=0)


# if __name__ == "__main__":
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     model_paths = [
#         "cv_out/best_model_fold1.pt",
#         "cv_out/best_model_fold2.pt",
#         "cv_out/best_model_fold3.pt",
#         "cv_out/best_model_fold4.pt",
#         "cv_out/best_model_fold5.pt",
#     ]

#     model_kwargs = {
#         "input_size": 32,
#         "hidden_size": 256,
#         "output_size": 2,
#         "dropout": 0.1,
#     }

#     # dataloader для теста
#     test_loader = DataLoader(
#         test_dataset,
#         batch_size=32,
#         shuffle=False,
#         num_workers=4,
#         pin_memory=True,
#     )

#     preds = ensemble_inference(
#         model_paths=model_paths,
#         dataloader=test_loader,
#         device=device,
#         model_ctor=ThreeLayerGRU,
#         model_kwargs=model_kwargs,
#     )

#     print(preds.shape)

In [None]:
import optuna
import mlflow
import mlflow.exceptions
import mlflow.tracking
from datetime import datetime
from typing import Any, Dict

# Импортируй свою cross_validate из того модуля, где ты её держишь
# from your_module import cross_validate

# --------------------------
# Helper: сохранять прогресс Optuna в CSV/JSON
# --------------------------
def save_study_progress(
    study: "optuna.study.Study",
    trial: "Optional[optuna.trial.FrozenTrial]",
    out_dir: str,
):
    """
    Optuna callback: dump the current state of ``study`` into ``out_dir``.

    Three files are (re)written on every call:
      * ``optuna_trials.csv``   - ``study.trials_dataframe()``
      * ``optuna_summary.json`` - compact per-trial summary
      * ``optuna_study.pkl``    - pickled study (best effort)

    Args:
        study: the study whose trials are exported.
        trial: the trial that has just finished. It is not used; the
            parameter exists to match the Optuna callback signature, so
            ``None`` is accepted for manual calls.
        out_dir: target directory, created if missing.
    """
    # Imported locally so the callback does not depend on these modules
    # having been imported by another notebook cell.
    import json
    import pickle

    os.makedirs(out_dir, exist_ok=True)

    csv_path = os.path.join(out_dir, "optuna_trials.csv")
    study.trials_dataframe().to_csv(csv_path, index=False)

    # Compact per-trial JSON summary, handy for a quick look.
    summary = [
        {
            "number": t.number,
            "state": str(t.state),
            "value": None if t.value is None else float(t.value),
            "params": t.params,
            "datetime_start": str(t.datetime_start),
            "datetime_complete": str(t.datetime_complete),
            "user_attrs": t.user_attrs,
            "system_attrs": t.system_attrs,
        }
        for t in study.trials
    ]
    json_path = os.path.join(out_dir, "optuna_summary.json")
    with open(json_path, "w") as f:
        json.dump(summary, f, indent=2, default=str)

    # Best effort: pickle the study object for later analysis. Serialise in
    # memory first so a failure cannot leave a truncated file behind.
    pkl_path = os.path.join(out_dir, "optuna_study.pkl")
    try:
        payload = pickle.dumps(study)
    except Exception as e:
        # pickling a Study may fail in some setups (e.g. DB-backed storage)
        print("Warning: could not pickle study:", e)
    else:
        with open(pkl_path, "wb") as pf:
            pf.write(payload)

# --------------------------
# Main: make objective with MLflow logging
# --------------------------
def make_objective_with_mlflow(
    X: np.ndarray,
    Y: np.ndarray,
    mask: np.ndarray,
    cross_validate_fn,
    n_splits: int = 5,
    seed: int = 42,
    base_model_ctor_kwargs: Optional[Dict[str, Any]] = None,
    base_fit_kwargs: Optional[Dict[str, Any]] = None,
    base_dataloader_kwargs: Optional[Dict[str, Any]] = None,
    mlflow_experiment: str = "optuna_cv",
    log_models_as_artifacts: bool = False,
    optuna_progress_dir: str = "optuna_progress",
):
    """
    Build an Optuna ``objective(trial)`` that runs cross-validation and logs
    every trial to MLflow.

    Args:
        X, Y, mask: arrays handed unchanged to ``cross_validate_fn``.
        cross_validate_fn: callable with the signature of ``cross_validate``;
            must return ``(fold_results, summary)`` and accept ``trial=``.
        n_splits: number of CV folds per trial.
        seed: seed forwarded to ``cross_validate_fn``.
        base_model_ctor_kwargs: model kwargs that the sampled values are
            layered on top of.
        base_fit_kwargs: fit kwargs that the sampled values are layered on
            top of.
        base_dataloader_kwargs: DataLoader kwargs; ``batch_size`` is always
            overridden with 256.
        mlflow_experiment: MLflow experiment name (set immediately).
        log_models_as_artifacts: if True, upload every file of the trial
            directory to MLflow (can be large).
        optuna_progress_dir: accepted for interface compatibility; not used
            by this function.

    Returns:
        ``objective(trial) -> float`` returning ``summary["pearson_mean"]``.
        The objective raises ``optuna.exceptions.TrialPruned`` when the trial
        is pruned during CV or when no ``pearson_mean`` is available, and
        re-raises any other exception after recording it.
    """
    # Local import so the objective does not depend on another notebook cell
    # having imported json.
    import json

    if base_model_ctor_kwargs is None:
        base_model_ctor_kwargs = {
            "input_size": X.shape[2],
            "hidden_size": 256,
            "output_size": Y.shape[2],
            "dropout": 0.1,
        }
    if base_fit_kwargs is None:
        base_fit_kwargs = {
            "num_epochs": 50,
            "lr": 1e-3,
            "weight_decay": 1e-5,
            "alpha": 1.0,
            "beta": 3.0,
            "grad_clip": 1.0,
            "accum_steps": 1,
            "use_amp": True,
            "save_path": None,
            "monitor": "weighted_pearson",
            "maximize_monitor": True,
        }
    if base_dataloader_kwargs is None:
        base_dataloader_kwargs = {"batch_size": 16, "num_workers": 4, "pin_memory": True}

    # ensure mlflow experiment exists
    mlflow.set_experiment(mlflow_experiment)

    def _json_safe(obj):
        # Trial user attributes must be JSON-serializable; stringify anything
        # that is not (e.g. numpy scalars) instead of failing the trial.
        try:
            return json.loads(json.dumps(obj, default=str))
        except (TypeError, ValueError):
            return str(obj)

    # NOTE(review): original author's open question - how exactly is the
    # score computed on CV?
    def objective(trial: "optuna.trial.Trial"):
        # ----------------------
        # Optuna search space (example - adjust to your needs)
        # ----------------------
        # The stored "hidden_size" parameter is the exponent; the model gets 2**exp.
        hidden_size = 2 ** trial.suggest_int("hidden_size", 4, 9)
        dropout = trial.suggest_float("dropout", 0.0, 0.5, step=0.05)
        # suggest_float(..., log=True) replaces the deprecated suggest_loguniform.
        lr = trial.suggest_float("lr", 1e-5, 5e-3, log=True)
        weight_decay = trial.suggest_float("weight_decay", 1e-15, 1e-2, log=True)
        batch_size = 256
        alpha = trial.suggest_float("alpha", 0.01, 1.0)
        # beta is tied to alpha rather than searched independently.
        beta = 1 - alpha
        weight_gamma = trial.suggest_float("weight_gamma", 0.4, 1.0, step=0.05)

        num_layers = trial.suggest_int("num_layers", 1, 5)
        in_layernorm = trial.suggest_categorical("in_layernorm", [True, False])
        out_layernorm = trial.suggest_categorical("out_layernorm", [True, False])

        # ----------------------
        # Prepare kwargs for cross_validate
        # ----------------------
        model_ctor_kwargs = dict(base_model_ctor_kwargs)
        model_ctor_kwargs.update({
            "input_size": X.shape[2],
            "hidden_size": hidden_size,
            "output_size": Y.shape[2],
            "dropout": dropout,
            "num_layers": num_layers,
            "in_layernorm": in_layernorm,
            "out_layernorm": out_layernorm,
        })

        fit_kwargs = dict(base_fit_kwargs)
        fit_kwargs.update({
            "lr": lr,
            "weight_decay": weight_decay,
            "alpha": alpha,
            # Previously computed but never forwarded, so the derived value
            # had no effect on training.
            "beta": beta,
            "weight_gamma": weight_gamma,
        })

        dataloader_kwargs = dict(base_dataloader_kwargs)
        dataloader_kwargs.update({"batch_size": batch_size})
        print(dataloader_kwargs)

        # ----------------------
        # Trial-specific directory (receives the per-fold models)
        # ----------------------
        trial_dir = os.path.join(
            "optuna_trials",
            f"trial_{trial.number}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        )
        os.makedirs(trial_dir, exist_ok=True)
        fit_kwargs_local = dict(fit_kwargs)
        # Placeholder template; cross_validate in this file replaces
        # save_path with a concrete per-fold path.
        fit_kwargs_local["save_path"] = os.path.join(trial_dir, "best_model_fold{fold}.pt")

        # ----------------------
        # One (nested) MLflow run per trial
        # ----------------------
        with mlflow.start_run(run_name=f"optuna_trial_{trial.number}", nested=True):
            mlflow.log_param("hidden_size", hidden_size)
            mlflow.log_param("num_layers", num_layers)
            mlflow.log_param("dropout", dropout)
            mlflow.log_param("in_layernorm", in_layernorm)
            mlflow.log_param("out_layernorm", out_layernorm)
            mlflow.log_param("lr", lr)
            mlflow.log_param("weight_decay", weight_decay)
            mlflow.log_param("batch_size", batch_size)
            mlflow.log_param("alpha", alpha)
            mlflow.log_param("beta", beta)
            mlflow.log_param("weight_gamma", weight_gamma)
            mlflow.log_param("n_splits", n_splits)
            mlflow.log_param("seed", seed)

            mlflow.set_tag("optuna_trial_number", str(trial.number))

            # ----------------------
            # Run cross-validation (costly)
            # ----------------------
            try:
                fold_results, summary = cross_validate_fn(
                    X, Y, mask,
                    n_splits=n_splits,
                    seed=seed,
                    model_ctor_kwargs=model_ctor_kwargs,
                    fit_kwargs=fit_kwargs_local,
                    dataloader_kwargs=dataloader_kwargs,
                    out_dir=trial_dir,
                    trial=trial,  # lets cross_validate report to the pruner
                )
            except optuna.exceptions.TrialPruned:
                # important: re-raise so Optuna can mark trial as pruned
                print(f"Trial {trial.number} pruned by pruner during cross-validation.")
                raise
            except Exception as e:
                # Record the failure in MLflow and on the trial, then re-raise
                # so Optuna marks the trial as failed.
                mlflow.log_param("failed", True)
                mlflow.set_tag("error", str(e))
                trial.set_user_attr("error", str(e))
                raise

            # ----------------------
            # Log summary metrics & artifacts to MLflow
            # ----------------------
            if summary is not None:
                for k, v in summary.items():
                    try:
                        mlflow.log_metric(k, float(v))
                    except Exception:
                        # non-numeric entries are kept as params instead
                        mlflow.log_param(f"metric_{k}", str(v))

            # Save fold_results as JSON artifact (best effort)
            try:
                fr_json = os.path.join(trial_dir, "fold_results.json")
                with open(fr_json, "w") as fw:
                    json.dump(fold_results, fw, default=str, indent=2)
                mlflow.log_artifact(fr_json, artifact_path="fold_results")
            except Exception as e:
                print("Warning: cannot save fold_results.json:", e)

            # Optionally: log all files saved in trial_dir (models, logs)
            if log_models_as_artifacts:
                try:
                    mlflow.log_artifacts(trial_dir, artifact_path="trial_files")
                except Exception as e:
                    print("Warning: mlflow.log_artifacts failed:", e)

            # Keep summary and fold_results on the trial for offline analysis.
            print(fold_results)
            trial.set_user_attr("summary", _json_safe(summary))
            trial.set_user_attr("fold_results", _json_safe(fold_results))

            # Metric returned to Optuna (to be maximized).
            pearson_mean = None if summary is None else summary.get("pearson_mean")
            if pearson_mean is None:
                # no valid metric: treat the trial as pruned
                raise optuna.exceptions.TrialPruned()

            return float(pearson_mean)

    return objective

In [None]:
def run_optuna_search(
    X, Y, mask,
    cross_validate_fn,
    n_trials: int = 30,
    n_splits: int = 10,
    seed: int = 42,
    study_name: str = "optuna_mlflow_cv",
    storage_url: str = "sqlite:///optuna_study.db",
    base_model_ctor_kwargs: Optional[Dict[str, Any]] = None,
    base_fit_kwargs: Optional[Dict[str, Any]] = None,
    base_dataloader_kwargs: Optional[Dict[str, Any]] = None,
    mlflow_experiment: str = "optuna_cv",
    log_models_as_artifacts: bool = False,
    optuna_progress_dir: str = "optuna_progress",
    n_jobs: int = 1
):
    """
    Create (or resume) an Optuna study and maximize the CV objective.

    The study uses a seeded TPE sampler and a median pruner and is stored at
    ``storage_url`` (``load_if_exists=True``, so an existing study with the
    same name is continued). After every trial, and once more at the end,
    the study state is exported with ``save_study_progress``.

    Args:
        X, Y, mask: data handed to ``make_objective_with_mlflow``.
        cross_validate_fn: cross-validation callable used by the objective.
        n_trials: number of trials to run in this call.
        n_splits: CV folds per trial.
        seed: seed for the sampler and for cross-validation.
        study_name, storage_url: study identity and storage backend.
        base_model_ctor_kwargs, base_fit_kwargs, base_dataloader_kwargs:
            base configuration forwarded to ``make_objective_with_mlflow``.
        mlflow_experiment: MLflow experiment name.
        log_models_as_artifacts: forwarded to the objective factory.
        optuna_progress_dir: directory receiving the CSV/JSON/pickle exports.
        n_jobs: forwarded to ``study.optimize``. Values other than 1 make
            Optuna run trials in threads of this process; for independent
            workers, launch several processes calling this function instead.

    Returns:
        The ``optuna.study.Study`` object.
    """
    # create study
    sampler = optuna.samplers.TPESampler(seed=seed)
    pruner = optuna.pruners.MedianPruner(n_warmup_steps=1)  # tunable
    study = optuna.create_study(
        study_name=study_name,
        storage=storage_url,
        sampler=sampler,
        pruner=pruner,
        direction="maximize",
        load_if_exists=True,
    )

    # make objective
    objective = make_objective_with_mlflow(
        X, Y, mask,
        cross_validate_fn=cross_validate_fn,
        n_splits=n_splits,
        seed=seed,
        base_model_ctor_kwargs=base_model_ctor_kwargs,
        base_fit_kwargs=base_fit_kwargs,
        base_dataloader_kwargs=base_dataloader_kwargs,
        mlflow_experiment=mlflow_experiment,
        log_models_as_artifacts=log_models_as_artifacts,
        optuna_progress_dir=optuna_progress_dir,
    )

    # Progress exports go to the directory the caller asked for (previously a
    # hard-coded "optuna_progress" was used and the argument was ignored).
    os.makedirs(optuna_progress_dir, exist_ok=True)

    def _save_progress(current_study, finished_trial):
        save_study_progress(current_study, finished_trial, out_dir=optuna_progress_dir)

    study.optimize(
        objective,
        n_trials=n_trials,
        n_jobs=n_jobs,
        callbacks=[_save_progress],
    )

    # once finished, save final trials CSV
    save_study_progress(study, None, out_dir=optuna_progress_dir)

    print("Optimization finished.")

    # study.best_value raises ValueError when no trial has completed (all
    # pruned or failed); still hand the study back in that case.
    has_completed_trial = any(
        t.state == optuna.trial.TrialState.COMPLETE for t in study.trials
    )
    if not has_completed_trial:
        print("No completed trials; best trial is not available.")
        return study

    print("Best trial:")
    print(" Value:", study.best_value)
    print(" Params:")
    for k, v in study.best_params.items():
        print(f" {k}: {v}")
    return study

# Обучение бейзлайна

In [None]:
class BaselineGRU(nn.Module):
    """Unidirectional multi-layer GRU with an optional input projection.

    Pipeline: [Linear(input_size -> hidden_size) + LayerNorm] (only when
    ``in_layernorm``) -> GRU -> [LayerNorm] (only when ``out_layernorm``)
    -> Linear(hidden_size -> output_size) applied at every time step.

    Args:
        input_size: number of input features per time step.
        hidden_size: GRU hidden size (also the projection width).
        output_size: number of predicted channels per time step.
        num_layers: number of stacked GRU layers. Defaults to 1 so that
            configurations that do not mention it (such as the default
            kwargs of ``cross_validate``) can still build the model.
        dropout: dropout between GRU layers; ignored when ``num_layers`` is 1.
        in_layernorm: enable the input projection + LayerNorm.
        out_layernorm: enable LayerNorm on the GRU outputs.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_layers: int = 1,
        dropout: float = 0.0,
        in_layernorm: bool = True,
        out_layernorm: bool = True
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.in_layernorm = in_layernorm
        self.out_layernorm = out_layernorm

        # optional input projection followed by LayerNorm
        if self.in_layernorm:
            self.input_proj = nn.Linear(self.input_size, self.hidden_size)
            self.input_ln = nn.LayerNorm(self.hidden_size)

        # Without the projection the GRU consumes the raw features.
        gru_in_size = self.hidden_size if self.in_layernorm else self.input_size
        self.gru = nn.GRU(
            input_size=gru_in_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
            # Only pass dropout to the GRU when it has more than one layer.
            dropout=self.dropout if self.num_layers > 1 else 0.0,
            bidirectional=False,
        )

        if self.out_layernorm:
            self.out_ln = nn.LayerNorm(self.hidden_size)

        self.head = nn.Linear(self.hidden_size, self.output_size)

        # Only the output head is re-initialised; other layers keep the
        # PyTorch defaults.
        nn.init.xavier_uniform_(self.head.weight)
        if self.head.bias is not None:
            nn.init.zeros_(self.head.bias)

    def forward(self, x: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (B, T, F) input batch.
            hx: optional initial hidden state, (num_layers, B, H).

        Returns:
            preds: (B, T, C) per-step predictions.
            h_n: (num_layers, B, H) final hidden state.
        """
        if self.in_layernorm:
            x = self.input_proj(x)
            x = self.input_ln(x)
        out, h_n = self.gru(x, hx)  # out: (B, T, H)
        if self.out_layernorm:
            out = self.out_ln(out)
        preds = self.head(out)  # (B, T, C)
        return preds, h_n

In [None]:
# Offset the validation sequence ids by 11000 before merging.
# NOTE(review): presumably this keeps them from colliding with the training
# ids — confirm that train seq_ix stays below 11000.
eval_data['seq_ix'] = eval_data['seq_ix'] + 11000
# Stack train and validation rows under a fresh 0..N-1 index, then put the
# rows into a random order (no fixed seed is used here).
tot_data = pd.concat([train_data, eval_data], ignore_index=True)
tot_data = tot_data.sample(n=len(tot_data))

In [None]:
# Convert a freshly re-shuffled copy of the merged frame into the
# (X, Y, mask) arrays used for training.
X_tot, Y_tot, mask_tot = df_to_numpy_arrays(
    tot_data.sample(n=len(tot_data)),
)

In [None]:
# Baseline run: 5-fold cross-validation of a 3-layer GRU on the merged data.

# Model architecture.
model_params = dict(
    input_size=X_tot.shape[2],
    hidden_size=128,
    output_size=2,
    num_layers=3,
    dropout=0.05,
    in_layernorm=True,
    out_layernorm=True,
)

# DataLoader settings (BATCH_SIZE is defined elsewhere in the notebook).
dataloader_params = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2,
)

# Optimisation / loss settings forwarded to fit(...).
fit_params = dict(
    num_epochs=25,
    lr=3e-4,
    weight_decay=1e-5,
    alpha=1.0,
    beta=0.5,
    weight_gamma=1.0,
    grad_clip=2.0,
    accum_steps=1,
    use_amp=False,
    monitor="weighted_pearson",
    maximize_monitor=True,
    log_dir="./experiments/first_trial",
)

fold_results, summary = cross_validate(
    X_tot,
    Y_tot,
    mask_tot,
    n_splits=5,
    seed=42,
    model_ctor_kwargs=model_params,
    fit_kwargs=fit_params,
    dataloader_kwargs=dataloader_params,
    out_dir="cv_results",
)

In [None]:
# model_params={
#     "input_size": X_tot.shape[2],
#     "hidden_size": 128,
#     "output_size": 2,
#     "num_layers": 3,
#     "dropout": 0.05,
#     "in_layernorm": True,
#     "out_layernorm": True
# }

# Hyper-parameter search: 200 Optuna trials, each scored by 3-fold CV.

# DataLoader settings; no batch_size is given here.
dataloader_params = dict(
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2,
)

# Fixed training settings shared by every trial.
fit_params = dict(
    num_epochs=25,
    grad_clip=1.0,
    accum_steps=1,
    use_amp=False,
    monitor="weighted_pearson",
    maximize_monitor=True,
    # log_dir="./experiments/first_trial",
    patience=5,
)

study = run_optuna_search(
    X=X_tot,
    Y=Y_tot,
    mask=mask_tot,
    cross_validate_fn=cross_validate,
    n_trials=200,
    n_splits=3,
    seed=42,
    study_name="my_optuna_study",
    storage_url="sqlite:///optuna_study.db",
    base_fit_kwargs=fit_params,
    base_dataloader_kwargs=dataloader_params,
    mlflow_experiment="my_mlflow_experiment",
    log_models_as_artifacts=False,  # True may upload large artifacts
    n_jobs=1,
)

# Экспорт результатов

In [None]:
from pathlib import Path
def export_gru_to_onnx(
    pytorch_model: nn.Module,
    input_size: int,
    hidden_size: int,
    num_layers: int,
    dropout: float,
    in_layernorm: bool,
    out_layernorm: bool,
    onnx_path: str,
    opset_version: int = 12,
    verbose: bool = False,
):
    """
    Export a PyTorch recurrent (GRU-based) model to ONNX.

    The model's forward must accept ``(x, h0)`` and return ``(preds, h_n)``,
    where preds has shape (B, T, C) and h_n has shape (num_layers, B, H).
    Tracing uses B=1, T=1; the batch and time axes are declared dynamic.

    Args:
        pytorch_model: Model to export. It is switched to eval mode.
        input_size: Number of input features F of the dummy input.
        hidden_size: Hidden width H of the dummy initial state.
        num_layers: Number of recurrent layers of the dummy initial state.
        dropout, in_layernorm, out_layernorm: Accepted for signature parity
            with the model constructor; not used by the export itself.
        onnx_path: Destination of the ONNX file.
        opset_version: ONNX opset passed to ``torch.onnx.export``.
        verbose: Forwarded to ``torch.onnx.export``.
    """
    pytorch_model.eval()

    # Build the dummy inputs on the model's own device; otherwise a model that
    # still lives on the GPU fails with a device mismatch during tracing.
    first_param = next(pytorch_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    x_dummy = torch.randn(
        1, 1, input_size, dtype=torch.float32, device=device
    )  # (B=1, T=1, F)
    h0_dummy = torch.zeros(
        num_layers, 1, hidden_size, dtype=torch.float32, device=device
    )  # (num_layers, B=1, H)

    input_names = ["x", "h0"]
    output_names = ["preds", "h_n"]

    # Batch is axis 0 for x/preds but axis 1 for the recurrent state.
    dynamic_axes = {
        "x": {0: "batch", 1: "seq"},
        "h0": {1: "batch"},
        "preds": {0: "batch", 1: "seq"},
        "h_n": {1: "batch"},
    }

    torch.onnx.export(
        pytorch_model,
        (x_dummy, h0_dummy),
        onnx_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
        do_constant_folding=True,
        verbose=verbose,
    )
    print(f"Exported ONNX model -> {onnx_path}")


def export_folds_to_onnx(
    checkpoint_paths,
    onnx_out_dir,
    input_size,
    hidden_size,
    output_size,
    num_layers,
    dropout,
    in_layernorm,
    out_layernorm
):
    """
    Rebuild a BaselineGRU for each checkpoint and export it to ONNX.

    The i-th checkpoint (1-based position in ``checkpoint_paths``) is written
    to ``<onnx_out_dir>/fold{i}.onnx``. The directory is created if missing.

    Args:
        checkpoint_paths: Iterable of checkpoint files. Each is either a dict
            holding the weights under "model_state_dict" or a bare state_dict.
        onnx_out_dir: Output directory for the ONNX files.
        input_size, hidden_size, output_size, num_layers, dropout,
        in_layernorm, out_layernorm: Constructor arguments of BaselineGRU;
            they must match the architecture the checkpoints were trained with.

    Returns:
        List of the exported ONNX file paths (as strings), in input order.
    """
    os.makedirs(onnx_out_dir, exist_ok=True)
    exported_paths = []
    for i, ckpt_path in enumerate(checkpoint_paths, start=1):
        model = BaselineGRU(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            num_layers=num_layers,
            dropout=dropout,
            in_layernorm=in_layernorm,
            out_layernorm=out_layernorm
        )
        # Load on CPU so the export does not depend on the training device.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        # Accept both the wrapped layout and a checkpoint that is itself the
        # state_dict; a wrong file still fails loudly in load_state_dict.
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            state_dict = ckpt["model_state_dict"]
        else:
            state_dict = ckpt
        model.load_state_dict(state_dict)
        model.eval()
        onnx_path = str(Path(onnx_out_dir) / f"fold{i}.onnx")
        export_gru_to_onnx(
            model,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            in_layernorm=in_layernorm,
            out_layernorm=out_layernorm,
            onnx_path=onnx_path,
        )
        exported_paths.append(onnx_path)
    return exported_paths

In [None]:
# Export two fold checkpoints of Optuna trial 95 to ONNX.
# NOTE: the output directory is spelled './onxx_models'; the glob that loads
# the ensemble later uses the same spelling, so the two must stay in sync.
export_folds_to_onnx(
    checkpoint_paths=[
        '/home/jovyan/optuna_trials/trial_95_20260301_035026/models/best_model_fold1.pt',
        '/home/jovyan/optuna_trials/trial_95_20260301_035026/models/best_model_fold2.pt',
    ],
    onnx_out_dir='./onxx_models',
    # Architecture arguments used to rebuild the model for these checkpoints.
    input_size=X_tot.shape[2],
    output_size=2,
    hidden_size=256,
    num_layers=4,
    dropout=0,
    in_layernorm=False,
    out_layernorm=True,
)

In [None]:
import sys
sys.path.insert(0,'./competition_package')
from utils import DataPoint, ScorerStepByStep

In [None]:
import glob
class PredictionModelONNX:
    """Stateful step-by-step predictor backed by one recurrent ONNX model.

    The ONNX graph is expected to take ``(x, h0)`` and return ``(preds, h_n)``.
    A hidden state is kept per ``seq_ix`` and advanced on EVERY data point,
    including warm-up points that need no prediction, so that the state at
    the first scored step has seen the same context as during training.

    Args:
        onnx_path: ONNX file to load (run with the CPU execution provider).
        feat_center: Optional per-feature offset subtracted from the raw state.
        feat_scale: Optional per-feature divisor applied after centering.
            Normalisation is applied only when both arrays are given.
    """

    # Fallback recurrent-state dimensions, used only when the ONNX input does
    # not declare them as static integers.
    DEFAULT_NUM_LAYERS = 3
    DEFAULT_HIDDEN_SIZE = 256

    def __init__(
        self,
        onnx_path: str,
        feat_center: Optional[np.ndarray] = None,
        feat_scale: Optional[np.ndarray] = None,
    ):
        self.sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        ins = self.sess.get_inputs()
        outs = self.sess.get_outputs()
        self.input_name_x = ins[0].name
        self.input_name_h0 = ins[1].name if len(ins) > 1 else None
        self.output_name_preds = outs[0].name
        self.output_name_hn = outs[1].name if len(outs) > 1 else None

        self.feat_center = feat_center.astype(np.float32) if feat_center is not None else None
        self.feat_scale = feat_scale.astype(np.float32) if feat_scale is not None else None

        # (num_layers, hidden_size) of the initial state, resolved once here
        # instead of on every sequence reset.
        self._h0_shape = self._infer_h0_shape(ins[1].shape if len(ins) > 1 else None)

        # per-sequence hidden states store: seq_ix -> h_n (num_layers, 1, hidden_size)
        self.hidden_states: Dict[int, np.ndarray] = {}
        # seq_ix -> last step_in_seq seen, used to detect a restarted sequence
        self._last_step: Dict[int, int] = {}

    @classmethod
    def _infer_h0_shape(cls, dims):
        """Return (num_layers, hidden_size) from the h0 input shape.

        ``dims`` is the shape reported by the session for the h0 input; it may
        contain ``None`` or symbolic strings for dynamic axes. Each dimension
        that is not a positive integer falls back to the class default.
        """
        num_layers = cls.DEFAULT_NUM_LAYERS
        hidden_size = cls.DEFAULT_HIDDEN_SIZE
        if dims is not None and len(dims) == 3:
            if isinstance(dims[0], (int, np.integer)) and dims[0] > 0:
                num_layers = int(dims[0])
            if isinstance(dims[2], (int, np.integer)) and dims[2] > 0:
                hidden_size = int(dims[2])
        return num_layers, hidden_size

    def reset_sequence(self, seq_ix: int):
        """Forget everything stored for ``seq_ix`` (no-op if unknown)."""
        self.hidden_states.pop(seq_ix, None)
        self._last_step.pop(seq_ix, None)

    def predict(self, data_point: "DataPoint") -> Optional[np.ndarray]:
        """Feed one data point through the model and advance the hidden state.

        Returns:
            Flat prediction array of shape (C,) when
            ``data_point.need_prediction`` is true, otherwise ``None``.
            The hidden state is updated in both cases.
        """
        seq = data_point.seq_ix
        step = data_point.step_in_seq

        # Start from a zero state for a new sequence, or when the step counter
        # did not advance (the sequence is being replayed from the start).
        previous_step = self._last_step.get(seq)
        sequence_restarted = previous_step is None or step <= previous_step
        if sequence_restarted or seq not in self.hidden_states:
            num_layers, hidden_size = self._h0_shape
            h0 = np.zeros((num_layers, 1, hidden_size), dtype=np.float32)
        else:
            h0 = self.hidden_states[seq]
        self._last_step[seq] = step

        # preprocess features
        x_raw = np.asarray(data_point.state, dtype=np.float32)  # (F,)
        if self.feat_center is not None and self.feat_scale is not None:
            x_proc = (x_raw - self.feat_center) / self.feat_scale
        else:
            x_proc = x_raw
        x_in = x_proc.reshape(1, 1, -1).astype(np.float32)  # (1,1,F)

        input_feed = {self.input_name_x: x_in}
        if self.input_name_h0 is not None:
            input_feed[self.input_name_h0] = np.asarray(h0, dtype=np.float32)

        # Only request outputs the graph actually has.
        fetch = [self.output_name_preds]
        if self.output_name_hn is not None:
            fetch.append(self.output_name_hn)
        outputs = self.sess.run(fetch, input_feed)

        # Carry the new state forward (or keep h0 for a graph without h_n).
        self.hidden_states[seq] = outputs[1] if self.output_name_hn is not None else h0

        if not data_point.need_prediction:
            return None
        return np.asarray(outputs[0]).reshape(-1)  # (C,)


class EnsemblePredictionModelONNX:
    """Average the step-by-step predictions of several ONNX models.

    Every data point is forwarded to every member, including warm-up points
    that need no prediction, so each member can keep its recurrent state in
    step with the sequence.

    Args:
        onnx_paths: Non-empty list of ONNX files, one per ensemble member.
        feat_centers: Optional list of per-member feature offsets
            (indexed like ``onnx_paths``).
        feat_scales: Optional list of per-member feature divisors
            (indexed like ``onnx_paths``).

    Raises:
        ValueError: If ``onnx_paths`` is empty.
    """

    def __init__(
        self,
        onnx_paths: list,
        feat_centers: Optional[list] = None,
        feat_scales: Optional[list] = None,
    ):
        # Explicit check instead of assert: asserts vanish under ``python -O``.
        if len(onnx_paths) < 1:
            raise ValueError("onnx_paths must contain at least one model path")
        self.models = [
            PredictionModelONNX(
                path,
                feat_center=None if feat_centers is None else feat_centers[i],
                feat_scale=None if feat_scales is None else feat_scales[i],
            )
            for i, path in enumerate(onnx_paths)
        ]
        # seq_ix -> last step_in_seq seen, used to detect a restarted sequence
        self._last_step: Dict[int, int] = {}

    def reset_sequence(self, seq_ix: int):
        """Drop the stored state of ``seq_ix`` in every member."""
        self._last_step.pop(seq_ix, None)
        for m in self.models:
            m.reset_sequence(seq_ix)

    def predict(self, data_point: "DataPoint") -> Optional[np.ndarray]:
        """Forward one data point to all members and average their outputs.

        Returns:
            Mean prediction of shape (C,) when ``data_point.need_prediction``
            is true, otherwise ``None``.
        """
        seq = data_point.seq_ix
        step = data_point.step_in_seq

        # Reset members when a sequence is seen for the first time or its step
        # counter did not advance (the sequence is replayed from the start).
        previous_step = self._last_step.get(seq)
        if previous_step is None or step <= previous_step:
            self.reset_sequence(seq)
        self._last_step[seq] = step

        # Call every member exactly once per data point.
        member_preds = [m.predict(data_point) for m in self.models]
        if not data_point.need_prediction:
            return None

        # A member that unexpectedly returns None contributes zeros; the output
        # size comes from the first real prediction (2 if there is none).
        template = next((p for p in member_preds if p is not None), None)
        out_size = 2 if template is None else len(template)
        filled = [
            np.zeros(out_size, dtype=np.float32)
            if p is None
            else np.asarray(p, dtype=np.float32)
            for p in member_preds
        ]

        stacked = np.stack(filled, axis=0)  # (M, C)
        return stacked.mean(axis=0)  # (C,)

In [None]:
import glob
import onnxruntime as ort
# Collect the exported fold models in lexicographic order and build the
# averaging ensemble without any feature normalisation.
onnx_paths = glob.glob("./onxx_models/fold*.onnx")
onnx_paths.sort()
ensemble = EnsemblePredictionModelONNX(
    onnx_paths=onnx_paths,
    feat_centers=None,
    feat_scales=None,
)

In [None]:
# Score the ensemble on the validation parquet with the competition package's
# ScorerStepByStep (presumably it feeds data points one by one to
# ensemble.predict — confirm in competition_package/utils).
scorer = ScorerStepByStep("./competition_package/datasets/valid.parquet")
results = scorer.score(ensemble)