# Colab setup

In [None]:
import sys
! {sys.executable} -m pip install pytorch-lifestream
! {sys.executable} -m pip install umap-learn
! {sys.executable} -m pip install catboost

# CoLES-demo-multimodal

**In this demo, we will try to show how Multimodal CoLES handles event data of different modalities.**

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
import os

# Create the output directories used later in the notebook.
# os.makedirs(..., exist_ok=True) is idempotent, so no existence check is
# needed, and it does not depend on a POSIX shell being available.
for output_dir in (
    "lightning_logs/CoLES-demo-multimodal",
    "CatBoostClassifier",
    "model",
):
    os.makedirs(output_dir, exist_ok=True)

# Libraries

In [None]:
! pip install pytorch-lifestream

In [None]:
from functools import partial
from datetime import timedelta
from time import time

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
import catboost
import umap

import torch
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.loggers import TensorBoardLogger

from sklearn.model_selection import train_test_split

from ptls.nn import TrxEncoder
from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder
from ptls.frames import PtlsDataModule
from ptls.frames.coles import CoLESModule
from ptls.frames.coles.split_strategy import SampleSlices
from ptls.frames.coles.multimodal_dataset import MultiModalDataset
from ptls.frames.coles.multimodal_dataset import MultiModalIterableDataset
from ptls.frames.coles.multimodal_dataset import MultiModalSortTimeSeqEncoderContainer
from ptls.frames.coles.multimodal_inference_dataset import MultiModalInferenceDataset
from ptls.frames.coles.multimodal_inference_dataset import MultiModalInferenceIterableDataset
from ptls.frames.inference_module import InferenceModuleMultimodal
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.data_load import IterableProcessingDataset
from ptls.data_load.utils import collate_feature_dict
from ptls.data_load.datasets import MemoryMapDataset
from ptls.preprocessing import PandasDataPreprocessor

# Working with data

## Data load

In [None]:
transactions = pd.read_csv("https://huggingface.co/datasets/dllllb/transactions-gender/resolve/main/transactions.csv.gz?download=true", compression="gzip")
targets = pd.read_csv("https://huggingface.co/datasets/dllllb/transactions-gender/resolve/main/gender_train.csv?download=true")

In [None]:
transactions = transactions.dropna().reset_index(drop=True)
transactions

Here
1. `customer_id` is the id of some user
2. `tr_datetime` is the time of the transaction
3. `mcc_code` is the merchant category code (MCC) of the transaction
4. `tr_type` is the type of transaction (what was paid for)
5. `amount` is the amount of the transaction
6. `term_id` is the id of the terminal where the transaction was carried out

We will predict the gender of each user based on their transactions.

In [None]:
targets

In [None]:
n_cutomers = len(pd.unique(transactions["customer_id"]))
n_labeling_cutomers = len(pd.unique(targets["customer_id"]))

print("n_cutomers:", n_cutomers)
print("n_labeling_cutomers:", n_labeling_cutomers)

In [None]:
list(transactions.columns)

In [None]:
sourceA = transactions[["customer_id", "tr_datetime", "mcc_code", "term_id"]]
sourceB = transactions[["customer_id", "tr_datetime", "tr_type", "amount"]]

In [None]:
sourceA

In [None]:
sourceB

In [None]:
# Simulate two partially observed modalities by removing a fixed number of
# randomly chosen events from each source.
# A seeded generator makes the dropped rows (and everything derived from them)
# reproducible after a kernel restart, matching the random_state=42 used for
# the train/valid/test splits further down.
# NOTE: this cell is not idempotent - re-running it drops additional rows.
drop_rng = np.random.default_rng(42)
sourceA_drop_indices = drop_rng.choice(sourceA.index, 130000, replace=False)
sourceB_drop_indices = drop_rng.choice(sourceB.index, 420000, replace=False)

sourceA = sourceA.drop(sourceA_drop_indices).reset_index(drop=True)
sourceB = sourceB.drop(sourceB_drop_indices).reset_index(drop=True)

In [None]:
sourceA

In [None]:
sourceB

## Preprocessing

In [None]:
mcc_code_in = len(np.unique((sourceA["mcc_code"])))
term_id_in = len(np.unique((sourceA["term_id"])))
tr_type_in = len(np.unique((sourceB["tr_type"])))

print("mcc_code_in:", mcc_code_in)
print("term_id_in:", term_id_in)
print("tr_type_in", tr_type_in)

In [None]:
def tr_datetime_preprocess(tr_datetime):
    """Convert a ``"<day> HH:MM:SS"`` string into an integer number of seconds.

    The string is split on whitespace into a day counter and a clock part; the
    clock part is split on ``":"`` into hours, minutes and seconds. The result
    is ``day * 86400`` plus the seconds represented by the clock part.

    Parameters:
        tr_datetime (str): Timestamp of the form ``"<day> HH:MM:SS"``.

    Returns:
        int: Total seconds elapsed since the start of day 0.

    Raises:
        ValueError: If the string does not split into exactly two
            whitespace-separated parts / three colon-separated parts, or a
            part is not an integer literal.
    """
    day_part, clock_part = tr_datetime.split()
    hours, minutes, secs = (int(piece) for piece in clock_part.split(":"))

    # timedelta normalises out-of-range fields (e.g. 60 seconds) on its own.
    offset = timedelta(days=int(day_part), hours=hours, minutes=minutes, seconds=secs)
    return int(offset.total_seconds())

In [None]:
sourceA["tr_datetime"] = sourceA["tr_datetime"].apply(tr_datetime_preprocess)
sourceB["tr_datetime"] = sourceB["tr_datetime"].apply(tr_datetime_preprocess)

In [None]:
sourceA_preprocessor = PandasDataPreprocessor(
    col_id="customer_id",
    col_event_time="tr_datetime",
    event_time_transformation="none",
    cols_category=["mcc_code", "term_id"],
    return_records=False,
)

# `tr_type` is a transaction-type code, and the sourceB TrxEncoder configuration
# lists it under `embeddings`, i.e. it is looked up in an embedding table.
# It therefore has to be encoded as a category here (like `mcc_code` and
# `term_id` in sourceA) rather than passed through as a raw numerical value;
# only `amount` is a genuine numerical feature.
sourceB_preprocessor = PandasDataPreprocessor(
    col_id="customer_id",
    col_event_time="tr_datetime",
    event_time_transformation="none",
    cols_category=["tr_type"],
    cols_numerical=["amount"],
    return_records=False,
)

In [None]:
processed_sourceA = sourceA_preprocessor.fit_transform(sourceA)
processed_sourceB = sourceB_preprocessor.fit_transform(sourceB)

In [None]:
processed_sourceA.columns = [
    "sourceA_" + str(col) if str(col) != "customer_id" else str(col)
    for col in processed_sourceA.columns
]

In [None]:
processed_sourceB.columns = [
    "sourceB_" + str(col) if str(col) != "customer_id" else str(col)
    for col in processed_sourceB.columns
]

In [None]:
joined_data = processed_sourceA.merge(processed_sourceB, how="outer", on="customer_id")

In [None]:
# After the outer merge, customers present in only one source have a missing
# value in every column of the other source; replace those with empty tensors
# so that each cell of a sequence column holds a tensor.
# pd.isna() is only called on scalar cells: for an array-like cell it returns
# an element-wise boolean array, whose truth value is ambiguous inside `if`.
joined_data = joined_data.applymap(
    lambda cell: torch.tensor([])
    if not hasattr(cell, "__len__") and pd.isna(cell)
    else cell
)

In [None]:
train_df, test_df = train_test_split(joined_data,
                                     test_size=0.1,
                                     random_state=42)
train_df, valid_df = train_test_split(train_df,
                                      test_size=0.1,
                                      random_state=42)

In [None]:
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

In [None]:
train_dict = train_df.to_dict("records")
valid_dict = valid_df.to_dict("records")
test_dict = test_df.to_dict("records")

In [None]:
source_features = {
    "sourceA": ["event_time", "mcc_code", "term_id"],
    "sourceB": ["event_time", "tr_type", "amount"]
}

In [None]:
splitter = SampleSlices(split_count=5, cnt_min=25, cnt_max=50)

In [None]:
train_multimodal_data = MultiModalIterableDataset(
    data = train_dict,
    splitter = splitter,
    source_features = source_features,
    col_id = "customer_id",
    col_time = "event_time",
    source_names = ("sourceA", "sourceB")
)

valid_multimodal_data = MultiModalIterableDataset(
    data = valid_dict,
    splitter = splitter,
    source_features = source_features,
    col_id = "customer_id",
    col_time = "event_time",
    source_names = ("sourceA", "sourceB")
)

In [None]:
train_loader = PtlsDataModule(
    train_data = train_multimodal_data,
    train_num_workers = 16,
    train_batch_size = 64,

    valid_data = valid_multimodal_data
)

In [None]:
sourceA_encoder_params = dict(
    embeddings_noise = 0.003,
    linear_projection_size = 64,
    embeddings = {
        "mcc_code": {"in": mcc_code_in, "out": 32},
        "term_id": {"in": term_id_in, "out": 32}
    },
)

sourceB_encoder_params = dict(
    embeddings_noise = 0.003,
    linear_projection_size = 64,
    embeddings = {
        "tr_type": {"in": tr_type_in, "out": 32},
    },
    numeric_values = {"amount": "identity"},
)

In [None]:
sourceA_encoder = TrxEncoder(**sourceA_encoder_params)
sourceB_encoder = TrxEncoder(**sourceB_encoder_params)

In [None]:
seq_encoder = MultiModalSortTimeSeqEncoderContainer(
    trx_encoders = {
        "sourceA": sourceA_encoder,
        "sourceB": sourceB_encoder,
    },

    input_size = 64,
    hidden_size = 256,
    seq_encoder_cls = RnnEncoder,
    type = "gru"
)

In [None]:
model = CoLESModule(
    seq_encoder = seq_encoder,
    optimizer_partial = partial(torch.optim.Adam, lr=0.004),
    lr_scheduler_partial = partial(torch.optim.lr_scheduler.StepLR, step_size=30, gamma=0.5)
)

In [None]:
logger = TensorBoardLogger("lightning_logs", name="CoLES-demo-multimodal")

pl_trainer = pl.Trainer(
    logger = logger,
    max_epochs = 1,
    accelerator = "gpu",
    devices = 1,
    enable_progress_bar = True
)

In [None]:
pl_trainer.fit(model, train_loader)

In [None]:
inf_test_data = MultiModalInferenceIterableDataset(
    data = test_dict,
    source_features = source_features,
    col_id = "customer_id",
    col_time = "event_time",
    source_names = ("sourceA", "sourceB")
)

In [None]:
inf_test_loader = DataLoader(
    dataset = inf_test_data,
    collate_fn = partial(inf_test_data.collate_fn, col_id="customer_id"),
    shuffle = False,
    num_workers = 0,
    batch_size = 8
)

In [None]:
inf_train_data = MultiModalInferenceIterableDataset(
    data = train_dict,
    source_features = source_features,
    col_id = "customer_id",
    col_time = "event_time",
    source_names = ("sourceA", "sourceB")
)

In [None]:
inf_train_loader = DataLoader(
    dataset = inf_train_data,
    collate_fn = partial(inf_train_data.collate_fn, col_id="customer_id"),
    shuffle = False,
    num_workers = 0,
    batch_size = 8
)

In [None]:
# inf_test_embeddings

## Tuning hyperparameters using the RankMe metric

In [None]:
!git clone https://github.com/google-research/google-research.git

In [None]:
!ls google-research/graph_embedding/metrics

In [None]:
import sys
sys.path.append("google-research/graph_embedding/metrics")

In [None]:
from metrics import (rankme,
        coherence,
        pseudo_condition_number,
        alpha_req,
        stable_rank,
        ne_sum,
        self_clustering)

In [None]:
from itertools import product

In [None]:
!pip install git+https://github.com/simonzhang00/ripser-plusplus.git

In [None]:
!ls 

In [None]:
import ripserplusplus as rpp
def ripser_metric(embeddings):
    """Compute a persistence-based metric of the embeddings with ripserplusplus.

    The embeddings are passed to ``rpp.run`` as a point cloud and the
    persistence (``death - birth``) of every bar in ``diagrams[0]`` is summed.

    Parameters:
        embeddings (array-like): Embedding matrix; converted with ``np.array``
            when it is not already an ``np.ndarray``.
            Presumably shaped (n_samples, n_features) - TODO confirm.

    Returns:
        tuple: ``(persistence_sum, elapsed_time)`` where ``elapsed_time`` is the
        wall-clock duration of the whole call in seconds.
    """
    start_time = time()

    if not isinstance(embeddings, np.ndarray):
        embeddings = np.array(embeddings)

    diagrams = rpp.run("--format point-cloud", embeddings)

    # Persistence of a bar is death - birth, which is positive for every bar
    # kept by the `death > birth` filter. The previous `birth - death` summed
    # the same quantities with the sign inverted.
    # NOTE(review): assumes diagrams[0] iterates as (birth, death) pairs and
    # that bars that never die are not encoded as +inf - confirm against the
    # ripserplusplus output format.
    persistence_sum = sum(
        death - birth for birth, death in diagrams[0] if death > birth
    )

    elapsed_time = time() - start_time

    return persistence_sum, elapsed_time

In [None]:
# batch_sizes = [64, 128]
# learning_rates = [0.001, 0.004]
# split_counts = [3, 5]
# cnt_min_values = [10, 25]
# cnt_max_values = [50, 100]

# # Генерация сетки гиперпараметров
# hyperparameter_grid = [
#     {
#         "batch_size": batch_size,
#         "learning_rate": lr,
#         "split_count": split_count,
#         "cnt_min": cnt_min,
#         "cnt_max": cnt_max,
#     }
#     for batch_size, lr, split_count, cnt_min, cnt_max in product(
#         batch_sizes, learning_rates, split_counts, cnt_min_values, cnt_max_values
#     )
# ]

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

In [None]:
checkpoints_path = "checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)

In [None]:
def create_datasets(train_dict, valid_dict, params, source_features):
    """Build the PtlsDataModule used for one multimodal CoLES training run.

    Parameters:
        train_dict: Training records passed as ``data`` to the train dataset.
        valid_dict: Validation records passed as ``data`` to the valid dataset.
        params (dict): Must provide ``"split_count"``, ``"cnt_min"``,
            ``"cnt_max"`` (forwarded to ``SampleSlices``) and ``"batch_size"``
            (used as the training batch size).
        source_features (dict): Per-source feature lists forwarded to both
            datasets.

    Returns:
        PtlsDataModule: Data module wrapping the train and validation datasets,
        with ``train_num_workers=0``.
    """
    slice_sampler = SampleSlices(
        split_count=params["split_count"],
        cnt_min=params["cnt_min"],
        cnt_max=params["cnt_max"],
    )

    # Train and validation datasets share the same splitter instance and the
    # same column configuration; only the records differ.
    shared_kwargs = dict(
        splitter=slice_sampler,
        source_features=source_features,
        col_id="customer_id",
        col_time="event_time",
        source_names=("sourceA", "sourceB"),
    )
    train_part = MultiModalIterableDataset(data=train_dict, **shared_kwargs)
    valid_part = MultiModalIterableDataset(data=valid_dict, **shared_kwargs)

    return PtlsDataModule(
        train_data=train_part,
        train_batch_size=params["batch_size"],
        train_num_workers=0,
        valid_data=valid_part,
    )

In [None]:

import glob

In [None]:
from sklearn.metrics import top_k_accuracy_score


def evaluate_model(model, pl_trainer, inf_loader, selected_metrics=None, topk=5):
    """Score a model by embedding metrics and a downstream CatBoost probe.

    Embeddings and metrics come from ``compute_metrics``. The embeddings are
    inner-joined with the notebook-level ``targets`` frame on ``customer_id``,
    split 70/30 (``random_state=42``) and used to fit a ``CatBoostClassifier``
    that predicts ``gender``.

    Parameters:
        model: Model to evaluate; switched to eval mode.
        pl_trainer: Trainer forwarded to ``compute_metrics``.
        inf_loader: Inference loader forwarded to ``compute_metrics``.
        selected_metrics (list, optional): Forwarded to ``compute_metrics``.
        topk (int): Accepted for compatibility; not used by this function.

    Returns:
        tuple: ``(metrics, times, accuracy)`` where ``accuracy`` is the
        classifier's score on the held-out 30%.
    """
    model.eval()
    metrics, times, embeddings_df = compute_metrics(model, pl_trainer, inf_loader, selected_metrics)

    # Keep only customers that have a label; index the result by customer.
    labelled = (
        embeddings_df
        .merge(targets.set_index("customer_id"), how="inner", on="customer_id")
        .set_index("customer_id")
    )
    features = labelled.drop(columns=["gender"])
    labels = labelled["gender"]

    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.3, random_state=42
    )

    probe = catboost.CatBoostClassifier(iterations=150, random_seed=42, verbose=0)
    probe.fit(X_train, y_train)

    return metrics, times, probe.score(X_test, y_test)

In [None]:
def compute_metrics(model, pl_trainer, inf_loader, selected_metrics=None):
    """
    Compute selected embedding quality metrics.

    Parameters:
        model (torch.nn.Module): Trained model to evaluate. It is switched to
            eval mode and, through the inference wrapper, has
            ``is_reduce_sequence`` set to True.
        pl_trainer (pl.Trainer): PyTorch Lightning trainer for inference.
        inf_loader (DataLoader): DataLoader for inference.
        selected_metrics (list, optional): List of metric names to compute.
                                           If None, computes all available metrics.
                                           Unknown names are reported and skipped.

    Returns:
        dict: Computed metrics.
        dict: Computed times (seconds) for each metric.
        pd.DataFrame: Extracted embeddings, including the ``customer_id`` column.
    """
    model.eval()

    inference_module = InferenceModuleMultimodal(
        model=model,
        pandas_output=True,
        drop_seq_features=True,
        model_out_name="emb",
        col_id="customer_id",
    )
    inference_module.model.is_reduce_sequence = True

    inf_embeddings = pd.concat(
        pl_trainer.predict(inference_module, inf_loader),
        axis=0,
    )

    # All metrics below run on a CPU float32 matrix, so convert the frame
    # directly instead of copying it to the model's device and back.
    embeddings_np = (
        inf_embeddings.drop(columns=["customer_id"]).to_numpy().astype(np.float32)
    )

    available_metrics = {
        "rankme": rankme,
        "coherence": coherence,
        "pseudo_condition_number": pseudo_condition_number,
        "alpha_req": alpha_req,
        "stable_rank": stable_rank,
        "ne_sum": ne_sum,
        "self_clustering": self_clustering,
        "ripser": ripser_metric
    }

    if selected_metrics is None:
        selected_metrics = list(available_metrics.keys())

    # The SVD is shared by every metric except "ripser"; skip it when no such
    # metric was requested.
    needs_svd = any(
        name != "ripser" and name in available_metrics for name in selected_metrics
    )
    u = s = None
    if needs_svd:
        u, s, _ = np.linalg.svd(embeddings_np, compute_uv=True, full_matrices=False)

    metrics = {}
    times = {}

    for metric_name in selected_metrics:
        if metric_name in available_metrics:
            if metric_name == "ripser":
                # ripser_metric measures its own elapsed time.
                metrics[metric_name], times[metric_name] = available_metrics[metric_name](embeddings_np)
            else:
                start_time = time()
                metrics[metric_name] = available_metrics[metric_name](embeddings_np, u=u, s=s)
                times[metric_name] = time() - start_time

            print(f"Computed {metric_name} in {times[metric_name]:.4f} seconds")
        else:
            print(f"metric {metric_name} not is in available metrics. Choose one of these: {available_metrics.keys()}")

    return metrics, times, inf_embeddings

In [None]:
# Baseline configuration: every run uses these values except for the single
# parameter that is being swept.
fixed_params = dict(
    batch_size=64,
    learning_rate=0.001,
    split_count=3,
    cnt_min=10,
    cnt_max=50,
    embedding_dim=16,
    category_embedding_dim=8,
    hidden_size=128,
)

# Candidate values per parameter for the one-at-a-time sweep.
variable_params = dict(
    batch_size=[16, 32, 64, 128],
    learning_rate=[0.0001, 0.001, 0.005, 0.01, 0.05],
    split_count=[2, 3, 5, 7],
    cnt_min=[5, 10, 15, 20, 25],
    cnt_max=[50, 80, 100, 150],
    embedding_dim=[8, 16, 24, 32],
    category_embedding_dim=[4, 8, 12, 16, 24],
    hidden_size=[32, 64, 128, 192, 256, 1024],
)

# One entry per (parameter, candidate value): a tuple of the swept parameter's
# name and the baseline configuration with that one value overridden.
all_hyperparameter_grids = [
    (swept_name, {**fixed_params, swept_name: swept_value})
    for swept_name, swept_values in variable_params.items()
    for swept_value in swept_values
]


In [None]:
# fixed_params = {
#     "batch_size": 64,
#     "learning_rate": 0.001,
#     "split_count": 3,
#     "cnt_min": 10,
#     "cnt_max": 50,
#     "embedding_dim": 16,  
#     "category_embedding_dim": 8,  
# }

# hidden_sizes = [64, 256, 512, 1024, 1696, 2048, 2424]

# all_hyperparameter_grids = [
#     {**fixed_params, "hidden_size": h_size} for h_size in hidden_sizes
# ]

In [None]:
num_epochs = 30
output_csv = "hidden_size_results.csv"

# Column layout of the results CSV: the hyperparameters, per-checkpoint
# bookkeeping, then one value column and one timing column per metric.
# fixed_params already contains "hidden_size", so the explicit entry below
# would otherwise appear twice; dict.fromkeys removes duplicates while keeping
# first-seen order.
_metric_keys = [
    "rankme", "coherence", "pseudo_condition_number",
    "alpha_req", "stable_rank", "ne_sum", "self_clustering", "ripser"
]
columns = list(dict.fromkeys([
    *fixed_params.keys(), "checkpoint", "epoch_num", "accuracy", "early_stop_epoch", "hidden_size",
    *("metric_" + key for key in _metric_keys),
    *("time_" + key for key in _metric_keys),
]))

In [None]:
class CustomLogger(pl.Callback):
    """Callback that prints epoch losses and tracks early-stopping progress.

    ``early_stopping_epoch`` holds the most recent epoch at whose end the
    trainer's early-stopping callback had ``wait_count == 0``; it stays
    ``None`` until that happens.
    """

    def __init__(self):
        super().__init__()
        self.early_stopping_epoch = None

    def on_train_epoch_end(self, trainer, pl_module):
        """Print train/val loss when both are logged and update the tracked epoch."""
        logged = trainer.callback_metrics
        train_loss = logged.get("train_loss", None)
        val_loss = logged.get("val_loss", None)

        # Only report when both losses are available for this epoch.
        if not (train_loss is None or val_loss is None):
            print(f"Epoch {trainer.current_epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        stopper = trainer.early_stopping_callback
        if stopper is None:
            return
        if stopper.wait_count == 0:
            self.early_stopping_epoch = trainer.current_epoch


custom_logger = CustomLogger()
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    patience=5,
    mode="min",
    verbose=True
)

In [None]:
import gc

In [None]:
! rm -rf checkpoints

In [None]:
cur_time = time()

# One-at-a-time hyperparameter sweep: for every grid entry, train a multimodal
# CoLES model, evaluate every saved checkpoint, append the results to
# `output_csv`, then delete the checkpoints.
#
# Each grid entry is a (swept_parameter_name, params_dict) tuple, so it is
# unpacked here; indexing the tuple itself with string keys raises TypeError.
for _swept_param_name, params in all_hyperparameter_grids:

    train_loader = create_datasets(train_dict, valid_dict, params, source_features)

    # NOTE: params["embedding_dim"] and params["category_embedding_dim"] are
    # not read anywhere below - the encoder sizes are hard-coded - so sweeping
    # those two keys retrains the same architecture.
    sourceA_encoder_params = dict(
        embeddings_noise=0.003,
        linear_projection_size=64,
        embeddings={
            "mcc_code": {"in": mcc_code_in, "out": 32},
            "term_id": {"in": term_id_in, "out": 32},
        },
    )

    sourceB_encoder_params = dict(
        embeddings_noise=0.003,
        linear_projection_size=64,
        embeddings={
            "tr_type": {"in": tr_type_in, "out": 32},
        },
        numeric_values={"amount": "identity"},
    )

    sourceA_encoder = TrxEncoder(**sourceA_encoder_params)
    sourceB_encoder = TrxEncoder(**sourceB_encoder_params)

    seq_encoder = MultiModalSortTimeSeqEncoderContainer(
        trx_encoders={
            "sourceA": sourceA_encoder,
            "sourceB": sourceB_encoder,
        },
        input_size=64,
        hidden_size=params["hidden_size"],
        seq_encoder_cls=RnnEncoder,
        type="gru",
    )

    model = CoLESModule(
        seq_encoder=seq_encoder,
        optimizer_partial=partial(torch.optim.Adam, lr=params["learning_rate"]),
        lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5),
    )

    early_stopping_callback = EarlyStopping(
        monitor="loss",
        patience=5,
        mode="min",
        verbose=True
    )

    # Shared prefix of this run's checkpoint files; used both to name them and
    # to find them again with glob below.
    run_tag = (
        f"model_{params['batch_size']}_{params['learning_rate']}_{params['split_count']}"
        f"_{params['cnt_min']}_{params['cnt_max']}_{params['hidden_size']}"
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoints_path,
        filename=run_tag + "{epoch:02d}",
        save_top_k=-1,
        every_n_epochs=1,
    )

    pl_trainer = pl.Trainer(
        callbacks=[checkpoint_callback, early_stopping_callback, custom_logger],
        default_root_dir=checkpoints_path,
        check_val_every_n_epoch=1,
        max_epochs=num_epochs,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        precision=16
    )
    model.train()
    pl_trainer.fit(model, train_loader)

    early_stop_epoch = custom_logger.early_stopping_epoch
    if early_stop_epoch is None:
        early_stop_epoch = num_epochs

    # Zero-padded epoch numbers make the lexicographic sort chronological.
    checkpoint_files = sorted(glob.glob(f"{checkpoints_path}/{run_tag}*.ckpt"))
    print(f"Elapsed time: {time() - cur_time:.2f} seconds")

    print(f'Early stop is {early_stop_epoch}')

    for i, checkpoint in enumerate(checkpoint_files):
        print(f"Processing checkpoint number {i}")
        model = CoLESModule.load_from_checkpoint(checkpoint, seq_encoder=seq_encoder)

        # evaluate_model expects (model, trainer, inference loader); the
        # training-set inference loader is used, matching the Optuna objective.
        metrics, times, accuracy = evaluate_model(model, pl_trainer, inf_train_loader)
        metrics_flattened = {f"metric_{k}": round(v, 4) for k, v in metrics.items()}
        times_flattened = {f"time_{k}": round(v, 4) for k, v in times.items()}

        new_result = {
            **params,
            "checkpoint": checkpoint,
            "epoch_num": int(i),
            "accuracy": accuracy,
            **metrics_flattened,
            **times_flattened,
            "early_stop_epoch": int(early_stop_epoch)
        }

        results = pd.DataFrame([new_result], columns=columns)
        print('----------')
        print(results["early_stop_epoch"])

        # Write the header once, then append one row per checkpoint so that
        # partial results survive an interrupted sweep.
        if not os.path.exists(output_csv):
            pd.DataFrame(columns=columns).to_csv(output_csv, mode="w", index=False, header=True)

        results.to_csv(output_csv, mode="a", header=False, index=False)

        del metrics, accuracy, new_result
        torch.cuda.empty_cache()
        gc.collect()

    print(f"Removing checkpoints for parameters: {params}")
    for checkpoint in checkpoint_files:
        os.remove(checkpoint)

    del model
    del train_loader
    torch.cuda.empty_cache()
    gc.collect()

print("Optimization complete!")

In [None]:
import optuna
import pandas as pd
import torch
import gc
from tqdm import tqdm

# Файл для сохранения результатов
output_csv = "optuna_hyperparameter_by_topK.csv"

# Hyperparameter search space definition
def define_search_space(trial):
    """Sample one hyperparameter configuration from an Optuna trial.

    Parameters:
        trial: Optuna trial (or any object exposing ``suggest_categorical`` and
            ``suggest_float``) used to draw the values.

    Returns:
        dict: Sampled values keyed by ``batch_size``, ``learning_rate``,
        ``hidden_size``, ``embedding_dim``, ``category_embedding_dim``,
        ``split_count``, ``cnt_min`` and ``cnt_max`` (in that order).
    """
    return {
        "batch_size": trial.suggest_categorical("batch_size", [32, 64]),
        # Log-uniform sampling; suggest_float(..., log=True) replaces the
        # deprecated suggest_loguniform.
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 5e-2, log=True),
        "hidden_size": trial.suggest_categorical("hidden_size", [64, 128]),
        "embedding_dim": trial.suggest_categorical("embedding_dim", [8, 16, 32]),
        "category_embedding_dim": trial.suggest_categorical("category_embedding_dim", [4, 8, 16]),
        "split_count": trial.suggest_categorical("split_count", [2, 3, 5]),
        "cnt_min": trial.suggest_categorical("cnt_min", [5, 10, 20]),
        "cnt_max": trial.suggest_categorical("cnt_max", [50, 80, 100]),
    }


metric_names = [
    "rankme", "coherence", "pseudo_condition_number",
    "alpha_req", "stable_rank", "ne_sum", "self_clustering", "ripser"
]


optuna_results = []

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
from ptls.data_load.padded_batch import PaddedBatch


class CustomCoLESModule(CoLESModule):
    """CoLESModule variant whose validation step also logs an embedding metric.

    ``custom_metric_name`` is the name passed to ``compute_metrics`` via
    ``selected_metrics``; every other constructor argument is forwarded to
    ``CoLESModule``.

    NOTE(review): this class is not instantiated anywhere in the visible part
    of the notebook and still contains debug prints - looks experimental.
    """

    def __init__(self, custom_metric_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.custom_metric_name = custom_metric_name
        


    def validation_step(self, batch, batch_idx):
        """Run one validation batch and log the configured embedding metric.

        Returns a dict with the cross-entropy ``"loss"`` and the metric value
        stored under ``self.custom_metric_name``.
        """
        print("valedation step")
        x, y = batch
        y = y.to(self.device)
        # Rebuild every PaddedBatch in x with its payload tensors and lengths
        # moved to this module's device. Assumes x is a dict-like of
        # PaddedBatch objects whose payload is a dict of tensors - TODO confirm.
        for key in x:
            if isinstance(x[key], PaddedBatch):
                
                x[key] = PaddedBatch(
                    payload={k: v.to(self.device) for k, v in x[key].payload.items()},
                    length=x[key].length.to(self.device)  
                )
            else:
                print(f"⚠️ [WARNING] Expected PaddedBatch but got {type(x[key])} for {key}")

        # Debug output. The second print reads `.device` on every value of x,
        # including ones that triggered the warning above.
        print(f"Model is on device: {next(self.parameters()).device}")
        print(f"x is on device: {[x[k].device for k in x]}")
        print(f"y is on device: {y.device}")

        y_hat = self(x)

        
        # NOTE(review): cross-entropy treats y_hat as class logits and y as
        # class indices - confirm that matches what this module outputs.
        loss = torch.nn.functional.cross_entropy(y_hat, y)

        # NOTE(review): this evaluates the notebook-level `model`, `pl_trainer`
        # and `inf_test_loader` globals rather than `self` / `self.trainer`, and
        # runs a full inference pass on every validation batch - confirm intent.
        metric_value, _, _ = compute_metrics(model, pl_trainer, inf_test_loader, selected_metrics=[self.custom_metric_name])

        print(f"[DEBUG] Logging metric: valid/{self.custom_metric_name} = {metric_value[self.custom_metric_name]}")

        # Logged straight to the trainer's logger, with the current epoch as
        # the step.
        self.trainer.logger.log_metrics({f"valid/{self.custom_metric_name}": metric_value[self.custom_metric_name]}, step=self.current_epoch)

        return {"loss": loss, self.custom_metric_name: metric_value[self.custom_metric_name]}

In [None]:
! rm -rf /kaggle/working/checkpoints

In [None]:
! rm /kaggle/working/optuna_best_trials_accuracy.csv

In [None]:
import optuna
import time
import pandas as pd
import torch
import gc
import os
import glob
from functools import partial
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from tqdm import tqdm
from time import time

# Файл для логирования результатов
output_csv = "optuna_results.csv"

# Метрики для проверки
metric_names = [
    "rankme", "coherence", "pseudo_condition_number",
    "alpha_req", "stable_rank", "ne_sum", "self_clustering", "ripser"
]

# Column layout of the Optuna results CSV: the hyperparameters, per-checkpoint
# bookkeeping, then one value column and one timing column per metric.
# fixed_params already contains "hidden_size", so the explicit entry below
# would otherwise appear twice; dict.fromkeys removes duplicates while keeping
# first-seen order.
optuna_columns = list(dict.fromkeys([
    *fixed_params.keys(), "checkpoint", "epoch_num", "accuracy", "topk_accuracy", "early_stop_epoch", "hidden_size",
    *("metric_" + key for key in metric_names),
    *("time_" + key for key in metric_names),
]))


def objective(trial):
    """Optuna objective: train one CoLES configuration and score its checkpoints.

    Samples hyperparameters with ``define_search_space``, trains a multimodal
    CoLES model, then evaluates every checkpoint saved for this trial with
    ``evaluate_model`` and a validation pass. One row per checkpoint is
    appended to ``output_csv``; the trial's checkpoints are deleted afterwards.

    Parameters:
        trial: Optuna trial used to sample the hyperparameters; its ``number``
            is part of the checkpoint file names.

    Returns:
        float: Highest downstream accuracy over the trial's checkpoints
        (``-inf`` when no checkpoint file was found).
    """
    torch.cuda.empty_cache()
    gc.collect()

    params = define_search_space(trial)

    # === Dataset ===
    data_module = create_datasets(train_dict, valid_dict, params, source_features)

    # === Encoders ===
    # NOTE: params["embedding_dim"] and params["category_embedding_dim"] are
    # sampled but not read below; the encoder sizes are hard-coded.
    sourceA_encoder = TrxEncoder(
        embeddings_noise=0.003,
        linear_projection_size=64,
        embeddings={
            "mcc_code": {"in": mcc_code_in, "out": 32},
            "term_id": {"in": term_id_in, "out": 32},
        },
    )

    sourceB_encoder = TrxEncoder(
        embeddings_noise=0.003,
        linear_projection_size=64,
        embeddings={
            "tr_type": {"in": tr_type_in, "out": 32},
        },
        numeric_values={"amount": "identity"},
    )

    seq_encoder = MultiModalSortTimeSeqEncoderContainer(
        trx_encoders={"sourceA": sourceA_encoder, "sourceB": sourceB_encoder},
        input_size=64,
        hidden_size=params["hidden_size"],
        seq_encoder_cls=RnnEncoder,
        type="gru",
    )

    model = CoLESModule(
        seq_encoder=seq_encoder,
        optimizer_partial=partial(torch.optim.Adam, lr=params["learning_rate"]),
        lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5),
    )

    # === Callbacks ===
    early_stopping_callback = EarlyStopping(
        monitor="valid/recall_top_k", patience=5, mode="max", verbose=True
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoints_path,
        filename=f"model_optuna_trial_{trial.number}_epoch={{epoch:02d}}",
        save_top_k=-1,
        monitor="valid/recall_top_k",
        mode="max",
    )

    trainer = pl.Trainer(
        callbacks=[checkpoint_callback, early_stopping_callback, custom_logger],
        default_root_dir=checkpoints_path,
        check_val_every_n_epoch=1,
        max_epochs= 1, # num_epochs,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        precision=16,
    )

    trainer.fit(model, datamodule=data_module)

    # Compare against None explicitly: epoch 0 is a valid value, and
    # `epoch or num_epochs` would silently replace it with num_epochs.
    early_stop_epoch = custom_logger.early_stopping_epoch
    if early_stop_epoch is None:
        early_stop_epoch = num_epochs

    # === Evaluate checkpoints ===
    checkpoint_files = sorted(
        glob.glob(f"{checkpoints_path}/model_optuna_trial_{trial.number}_epoch=*.ckpt")
    )

    best_acc = float("-inf")

    for i, checkpoint in enumerate(checkpoint_files):
        model = CoLESModule.load_from_checkpoint(checkpoint, seq_encoder=seq_encoder)
        metrics, times, acc = evaluate_model(model, trainer, inf_train_loader)

        metrics_flattened = {f"metric_{k}": round(v, 4) for k, v in metrics.items()}
        times_flattened = {f"time_{k}": round(v, 4) for k, v in times.items()}

        # Use a separate trainer for the validation pass. Rebinding `trainer`
        # here would make every checkpoint after the first be evaluated with
        # this bare trainer instead of the one configured above.
        val_trainer = pl.Trainer(accelerator="gpu", devices=1)

        val_metrics = val_trainer.validate(model=model, datamodule=data_module)
        recall_top_k = val_metrics[0].get("valid/recall_top_k", None)
        result = {
            **params,
            "checkpoint": checkpoint,
            "epoch_num": i,
            "accuracy": acc,
            "topk_accuracy": recall_top_k,
            **metrics_flattened,
            **times_flattened,
            "early_stop_epoch": early_stop_epoch,
        }

        results_df = pd.DataFrame([result], columns=optuna_columns)

        # Write the header once, then append one row per checkpoint.
        if not os.path.exists(output_csv):
            pd.DataFrame(columns=optuna_columns).to_csv(output_csv, index=False, header=True)
        results_df.to_csv(output_csv, mode="a", header=False, index=False)

        # current_metric_value = metrics.get(metric_name, float("-inf"))
        best_acc = max(best_acc, acc)

        del model, result, metrics
        torch.cuda.empty_cache()
        gc.collect()

    for ckpt in checkpoint_files:
        os.remove(ckpt)

    return best_acc


In [None]:
import optuna
import pandas as pd
import os
from time import time
from functools import partial

# Optuna settings
num_trials = 10
cur_time = time()

# Path of the CSV file that accumulates the best trial of each study
best_trials_csv = "optuna_best_trials_accuracy.csv"

# Single column order shared by the header and by every appended row.
# Rows are appended with header=False, so they must follow the header's order;
# best_trial.params is ordered as in define_search_space, which differs from
# fixed_params.
best_trial_columns = ["value", *fixed_params.keys()]

# Create the file with its header if it does not exist yet
if not os.path.exists(best_trials_csv):
    pd.DataFrame(columns=best_trial_columns).to_csv(best_trials_csv, index=False)

# objective() returns the best downstream accuracy of a trial, so the study
# maximizes its value.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=num_trials)

# Extract the best result
best_trial = study.best_trial
best_result = {
    "value": best_trial.value,
    **best_trial.params
}

# Save the best result, aligned to the header's column order
df_best = pd.DataFrame([best_result], columns=best_trial_columns)
df_best.to_csv(best_trials_csv, mode="a", header=False, index=False)

# Summary
print(f"✅ Optimization completed (direction: maximize)")
print(f"⏱️ Time passed: {time() - cur_time:.2f} sec")
print(f"🥇 Best trial value: {best_trial.value}")
print(f"📊 Params: {best_trial.params}")


## Eval model with best hyperparams

In [None]:
input_csv = "/kaggle/input/gender-tr-best-params/gender_tr_optuna_best_params.csv"
best_trials_df = pd.read_csv(input_csv)

In [None]:
best_trials_df.reset_index(inplace=True)
best_trials_df.rename(columns={"index": "metric", "metric":"value",
                              "value":"batch_size", "batch_size":"learning_rate", "learning_rate":"hidden_size"}, inplace=True)

In [None]:
best_trials_df.rename(columns={"cnt_min": "embedding_dim", "embedding_dim":"cnt_min",
                              "category_embedding_dim":"cnt_max", "cnt_max":"category_embedding_dim"}, inplace=True)

In [None]:
best_trials_df

In [None]:
! rm -rf /kaggle/working/checkpoints

In [None]:
! rm /kaggle/working/optuna_best_metrics_eval.csv

In [None]:
import pandas as pd
import torch
import gc
import os
import glob
from functools import partial
from time import time
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from itertools import islice



# Directory where ModelCheckpoint stores the per-epoch checkpoints of the current run.
checkpoints_path = "checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)

# Upper bound on training epochs; also reported as `early_stop_epoch`
# when early stopping never triggers.
MAX_EPOCHS = 30

# Number of leading rows of `best_trials_df` that are skipped by the loop below.
# NOTE(review): presumably rows already evaluated in an earlier session — set to 0
# to process every row.
SKIP_FIRST_ROWS = 2

# Column layout of each per-metric results CSV: the input columns, the per-checkpoint
# bookkeeping fields, then one metric_* and one time_* column per entry of `metric_names`.
columns = list(best_trials_df.columns) + [
    "checkpoint", "epoch_num", "accuracy", "early_stop_epoch"
] + [f"metric_{m}" for m in metric_names] + [f"time_{m}" for m in metric_names]

for idx, row in islice(best_trials_df.iterrows(), SKIP_FIRST_ROWS, None):
    metric_name = row["metric"]
    print(f"\n=== Processing best params for metric: {metric_name} ===")
    output_csv = f"optuna_best_metrics_eval_{metric_name}.csv"

    print(row)

    # Cast the row values to the Python types expected downstream.
    params = {
        "batch_size": int(row["batch_size"]),
        "learning_rate": float(row["learning_rate"]),
        "split_count": int(row["split_count"]),
        "cnt_min": int(row["cnt_min"]),
        "cnt_max": int(row["cnt_max"]),
        "embedding_dim": int(row["embedding_dim"]),
        "category_embedding_dim": int(row["category_embedding_dim"]),
        "hidden_size": int(row["hidden_size"]),
    }

    train_loader = create_datasets(train_dict, valid_dict, params, source_features)

    # NOTE(review): the embedding sizes (32) and projection size (64) are hard-coded here;
    # params["embedding_dim"] and params["category_embedding_dim"] are not applied to the
    # encoders in this cell — confirm this matches the model built during the search.
    sourceA_encoder = TrxEncoder(
        embeddings={"mcc_code": {"in": mcc_code_in, "out": 32}, "term_id": {"in": term_id_in, "out": 32}},
        embeddings_noise=0.003,
        linear_projection_size=64,
    )
    sourceB_encoder = TrxEncoder(
        embeddings={"tr_type": {"in": tr_type_in, "out": 32}},
        numeric_values={"amount": "identity"},
        embeddings_noise=0.003,
        linear_projection_size=64,
    )

    seq_encoder = MultiModalSortTimeSeqEncoderContainer(
        trx_encoders={"sourceA": sourceA_encoder, "sourceB": sourceB_encoder},
        input_size=64,
        hidden_size=params["hidden_size"],
        seq_encoder_cls=RnnEncoder,
        type="gru",
    )

    model = CoLESModule(
        seq_encoder=seq_encoder,
        optimizer_partial=partial(torch.optim.Adam, lr=params["learning_rate"]),
        lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5),
    )

    early_stopping_callback = EarlyStopping(
        monitor="loss", patience=5, mode="min", verbose=True
    )
    # save_top_k=-1 with every_n_epochs=1 keeps one checkpoint per epoch.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoints_path,
        filename=f"best_{metric_name}_trial_{idx}_epoch={{epoch:02d}}",
        save_top_k=-1,
        every_n_epochs=1,
    )

    trainer = Trainer(
        callbacks=[checkpoint_callback, early_stopping_callback],
        default_root_dir=checkpoints_path,
        check_val_every_n_epoch=1,
        max_epochs=MAX_EPOCHS,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        precision=16
    )

    trainer.fit(model, train_loader)

    # EarlyStopping stores the epoch at which it halted training in `stopped_epoch`
    # and leaves it at 0 when it never triggered; fall back to the epoch limit then.
    # (The logger has no `early_stopping_epoch` attribute, so reading it from
    # `trainer.logger` always produced the fallback value.)
    early_stop_epoch = early_stopping_callback.stopped_epoch or MAX_EPOCHS

    # Checkpoints written for this row, ordered by their zero-padded epoch suffix.
    checkpoint_files = sorted(
        glob.glob(f"{checkpoints_path}/best_{metric_name}_trial_{idx}_epoch=*.ckpt")
    )

    # Release the trained module before the checkpoints are reloaded one by one.
    model.cpu()
    del model
    torch.cuda.empty_cache()

    for i, checkpoint in enumerate(checkpoint_files):
        print(f"Evaluating checkpoint #{i}")
        model = CoLESModule.load_from_checkpoint(checkpoint, seq_encoder=seq_encoder)
        metrics, times, accuracy = evaluate_model(model, trainer)

        row_result = {
            **params,
            "metric": metric_name,
            "checkpoint": checkpoint,
            "epoch_num": i,
            "accuracy": accuracy,
            "early_stop_epoch": early_stop_epoch,
            **{f"metric_{k}": round(v, 4) for k, v in metrics.items()},
            **{f"time_{k}": round(v, 4) for k, v in times.items()}
        }

        # Append the row right away so finished checkpoints survive an interrupted run;
        # the header is written only when the file does not exist yet.
        result_df = pd.DataFrame([row_result], columns=columns)
        result_df.to_csv(
            output_csv,
            mode="a",
            index=False,
            header=not os.path.exists(output_csv),
        )

        del model, result_df
        torch.cuda.empty_cache()
        gc.collect()

    # Checkpoints are only needed for the evaluation above; delete them to free disk space.
    for ckpt in checkpoint_files:
        os.remove(ckpt)
    del trainer, train_loader, seq_encoder
    torch.cuda.empty_cache()
    gc.collect()

print("✅ Evaluation of best params complete.")