<a href="https://colab.research.google.com/github/Yanina-Kutovaya/GNN/blob/main/notebooks/catboost_stratified_split.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Функция `stratified_train_test_validation_split`: загрузка данных из PostgreSQL в CatBoost для наборов, которые не помещаются в оперативную память, но помещаются на диск.

В классе с unit-тестами `TestStratifiedSplit` реализованы следующие тесты:

- `test_invalid_train_val_size_raises_value_error` — обработка ошибочных значений `train_size` и `val_size`
- `test_sql_queries_built_correctly` — корректное формирование SQL-запросов
- `test_pool_objects_returned_with_valid_data` — возвращаются объекты CatBoost `Pool`
- `test_class_weights_calculated_correctly` — верный расчёт весов классов
- `test_test_data_saved_to_file_correctly` — сохранение тестовой выборки в CSV

In [1]:
!pip install -q catboost

In [2]:
import logging
import os
import psycopg2
import pandas as pd
import numpy as np
from catboost import Pool

# Module-level logger. basicConfig attaches a single console (stderr)
# handler to the root logger at INFO level; per the logging docs it is a
# no-op if the root logger already has handlers.
logger = logging.getLogger(__name__)
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def _split_bounds(total, train_size, val_size):
    """Return half-open row-rank ranges ``{split: (start, end)}`` for one class.

    The three ranges are contiguous and tile ``[0, total)`` exactly, so every
    row of the class lands in exactly one split; rounding leftovers go to
    ``'test'``.
    """
    train_end = int(total * train_size)
    # min() keeps float slop from pushing the val boundary past the class size.
    val_end = min(total, train_end + int(total * val_size))
    return {
        'train': (0, train_end),
        'val': (train_end, val_end),
        'test': (val_end, total),
    }


def _build_split_query(table_name, target_column, id_column, class_counts,
                       split, train_size, val_size, seed):
    """Build the SQL that selects the rows of ``split`` across all classes.

    Inside each class, rows are ranked by a hash of ``id_column`` salted with
    ``seed`` (ties broken by the id itself) and the rank range returned by
    ``_split_bounds`` is selected, so every split keeps the class proportions
    up to integer rounding. Class values and the seed are bound as query
    parameters; table and column names are interpolated as-is and must come
    from trusted code.

    Returns ``(query, params)``; ``query`` is ``""`` when no class has any rows
    in this split.
    """
    parts = []
    params = []
    for cls, count in class_counts.items():
        start, end = _split_bounds(count, train_size, val_size)[split]
        if start >= end:
            continue  # this class contributes nothing to the split; skip the scan
        parts.append(f"""
            SELECT * FROM (
                SELECT *,
                ROW_NUMBER() OVER (
                    ORDER BY HASHTEXT({id_column}::TEXT || %s), {id_column}
                ) - 1 AS _split_rn
                FROM {table_name}
                WHERE {target_column} = %s
            ) AS sub
            WHERE _split_rn >= {start} AND _split_rn < {end}
        """)
        params.extend([str(seed), cls])
    return " UNION ALL ".join(parts), params


def _iter_batches(connection_params, query, params, feature_cols, target_col,
                  class_weights, batch_size):
    """Yield ``(features, target, weights)`` numpy batches for ``query``.

    Uses its own connection and a named (server-side) cursor so that
    ``fetchmany`` pulls ``batch_size`` rows at a time from PostgreSQL instead
    of the whole result being buffered client-side by ``execute``. The cursor
    and connection are closed when the generator finishes, raises, or is
    closed.
    """
    if not query:
        return
    conn = psycopg2.connect(**connection_params)
    try:
        cursor = conn.cursor(name='stratified_split_cursor')
        try:
            cursor.execute(query, params)
            colnames = None
            while True:
                records = cursor.fetchmany(batch_size)
                if not records:
                    break
                if colnames is None:
                    # A named cursor's description may be None until the
                    # first fetch, so it is read lazily.
                    colnames = [desc[0] for desc in cursor.description]
                df = pd.DataFrame(records, columns=colnames)
                yield (
                    df[feature_cols].values,
                    df[target_col].values,
                    df[target_col].map(class_weights).values,
                )
        finally:
            cursor.close()
    finally:
        conn.close()


def _load_split(connection_params, query, params, feature_cols, target_col,
                class_weights, batch_size):
    """Materialize one split as ``(features, target, weights)`` numpy arrays.

    When the split yields no rows, a single placeholder row (all-zero features,
    label 0, weight 1.0) is returned and a warning is logged, so a CatBoost
    Pool can still be constructed.
    NOTE(review): the placeholder is a fabricated sample; callers may prefer an
    error here.
    """
    all_features = []
    all_target = []
    all_weights = []
    for features, target, weights in _iter_batches(
        connection_params, query, params, feature_cols, target_col,
        class_weights, batch_size,
    ):
        all_features.append(features)
        all_target.append(target)
        all_weights.append(weights)

    if not all_features:
        logger.warning("Нет данных для обучения")
        return (
            np.zeros((1, len(feature_cols)), dtype=np.float32),
            np.array([0], dtype=np.int64),
            np.array([1.0], dtype=np.float32),
        )

    return (
        np.vstack(all_features),
        np.concatenate(all_target),
        np.concatenate(all_weights),
    )


def _fetch_metadata(connection_params, table_name, target_column, id_column):
    """Return ``(feature_columns, class_counts)`` for the table.

    Feature columns are all table columns except ``target_column`` and
    ``id_column``. Classes with a NULL target are dropped (with a warning),
    because such rows cannot be used as labelled data. The connection is
    closed before returning; psycopg2 opens a transaction on the first
    ``execute`` by default, so this keeps it from sitting idle during the
    long split loads.
    """
    conn = None
    cursor = None
    try:
        conn = psycopg2.connect(**connection_params)
        cursor = conn.cursor()
        logger.info("Подключено к базе данных")

        cursor.execute(f"SELECT column_name FROM information_schema.columns WHERE table_name='{table_name}'")
        # dict.fromkeys removes duplicates (e.g. a same-named table in another
        # schema) while preserving order.
        # NOTE(review): the query is not filtered by schema; confirm the table
        # name is unique across schemas.
        columns = list(dict.fromkeys(row[0] for row in cursor.fetchall()))
        if not columns:
            raise ValueError(f"Не найдены столбцы таблицы {table_name}")
        # The id column only drives the split; using it as a feature would leak
        # row identity into the model.
        feature_columns = [
            col for col in columns if col not in (target_column, id_column)
        ]
        logger.info(f"Определены {len(feature_columns)} фичей: {feature_columns}")

        cursor.execute(f"SELECT {target_column}, COUNT(*) FROM {table_name} GROUP BY {target_column}")
        class_counts = {}
        for cls, count in cursor.fetchall():
            if cls is None:
                logger.warning(f"Пропущено {count} записей с пустым {target_column}")
                continue
            class_counts[cls] = count
        logger.info(f"Количество записей по классам: {class_counts}")
        return feature_columns, class_counts
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
            logger.info("Соединение с БД закрыто")


def stratified_train_test_validation_split(
    connection_params,
    table_name,
    target_column,
    id_column='id',
    train_size=0.6,
    val_size=0.2,
    batch_size=1000,
    seed=42,
    output_test_file=None
):
    """Split a PostgreSQL table into stratified train/validation/test parts.

    Within each class (distinct non-NULL value of ``target_column``), rows are
    ordered by a hash of ``id_column`` salted with ``seed`` and cut into
    contiguous ``train_size`` / ``val_size`` / remainder portions, so every
    split preserves the class proportions of the table. ``id_column`` and
    ``target_column`` are not used as features.

    Train and validation rows are streamed from the database in batches of
    ``batch_size`` and then materialized in memory as numpy arrays to build
    CatBoost ``Pool`` objects. Test rows are streamed straight to
    ``output_test_file`` (CSV with the feature columns, the target and a
    ``weight`` column) when it is given; otherwise the test split is not read.

    Args:
        connection_params: keyword arguments for ``psycopg2.connect``.
        table_name: source table (interpolated into SQL; must be trusted).
        target_column: label column (interpolated into SQL; must be trusted).
        id_column: column hashed to order rows; excluded from the features.
        train_size: fraction of each class assigned to train, in [0, 1].
        val_size: fraction of each class assigned to validation, in [0, 1].
        batch_size: rows fetched per round trip; must be >= 1.
        seed: salt mixed into the id hash; a different seed gives a different
            but reproducible split.
        output_test_file: optional CSV path for the test split.

    Returns:
        ``(train_pool, val_pool, class_weights)``, where ``class_weights`` maps
        each class to ``total / (n_classes * class_count)``.

    Raises:
        ValueError: invalid sizes or ``batch_size``, no columns found, or no
            labelled rows. Database errors are logged and re-raised.
    """
    logger.info("Начало разбиения данных")

    if not (0 <= train_size <= 1 and 0 <= val_size <= 1):
        logger.error("train_size и val_size должны быть в диапазоне [0, 1]")
        raise ValueError("train_size и val_size должны быть в диапазоне [0, 1].")
    # Compare with a tolerance: 1.0 - 0.8 - 0.2 evaluates to about -5.6e-17,
    # so a subtraction-based check would reject the valid split (0.8, 0.2).
    if train_size + val_size > 1 + 1e-9:
        logger.error("Сумма train_size и val_size больше 1")
        raise ValueError("Сумма train_size и val_size должна быть <= 1.")
    if batch_size < 1:
        logger.error("batch_size должен быть >= 1")
        raise ValueError("batch_size должен быть >= 1.")

    try:
        feature_columns, class_counts = _fetch_metadata(
            connection_params, table_name, target_column, id_column
        )
        if not class_counts:
            raise ValueError(
                f"В таблице {table_name} нет записей с непустым {target_column}"
            )

        # Balanced class weights: total / (n_classes * class_count).
        total_samples = sum(class_counts.values())
        class_weights = {
            cls: total_samples / (len(class_counts) * count)
            for cls, count in class_counts.items()
        }
        logger.info(f"Вычислены веса классов: {class_weights}")

        def split_query(split):
            return _build_split_query(
                table_name, target_column, id_column, class_counts,
                split, train_size, val_size, seed,
            )

        batch_args = (feature_columns, target_column, class_weights, batch_size)

        train_query, train_params = split_query('train')
        train_data, train_label, train_weight = _load_split(
            connection_params, train_query, train_params, *batch_args
        )
        val_query, val_params = split_query('val')
        val_data, val_label, val_weight = _load_split(
            connection_params, val_query, val_params, *batch_args
        )

        train_pool = Pool(
            data=train_data,
            label=train_label,
            weight=train_weight,
            feature_names=feature_columns
        )
        val_pool = Pool(
            data=val_data,
            label=val_label,
            weight=val_weight,
            feature_names=feature_columns
        )

        if output_test_file:
            logger.info(f"Сохранение тестовых данных в файл: {output_test_file}")
            test_query, test_params = split_query('test')
            csv_columns = feature_columns + [target_column, 'weight']
            # pandas requires newline='' for text handles passed to to_csv.
            with open(output_test_file, 'w', newline='', encoding='utf-8') as f:
                # Header is written up front so an empty split still yields a
                # readable CSV.
                pd.DataFrame(columns=csv_columns).to_csv(f, index=False)
                for features, target, weights in _iter_batches(
                    connection_params, test_query, test_params, *batch_args
                ):
                    batch_df = pd.DataFrame(features, columns=feature_columns)
                    batch_df[target_column] = target
                    batch_df['weight'] = weights
                    batch_df.to_csv(f, index=False, header=False)

        logger.info("Разделение данных успешно завершено")
        return train_pool, val_pool, class_weights

    except Exception:
        logger.exception("Ошибка при разделении данных")
        raise


import unittest
from unittest.mock import patch, MagicMock, Mock
import os
import tempfile
import numpy as np
import pandas as pd
from catboost import Pool
import logging


class TestStratifiedSplit(unittest.TestCase):
    """Unit tests for ``stratified_train_test_validation_split`` with psycopg2 mocked.

    Every connection the function opens returns the same mocked cursor, and
    ``fetchmany`` hands out ``data_chunks`` from a single per-test counter, so
    the first split that is fetched consumes all rows and later splits see
    none.
    """

    @classmethod
    def setUpClass(cls):
        """Silence all logging while the tests run to keep the output clean."""
        logging.disable(logging.CRITICAL)

    @classmethod
    def tearDownClass(cls):
        """Re-enable logging after all tests.

        ``logging.disable`` takes a severity threshold, not a logger level;
        ``NOTSET`` removes the override. Passing the root logger's level here
        would keep every message at or below that level disabled afterwards.
        """
        logging.disable(logging.NOTSET)

    def setUp(self):
        """Patch ``psycopg2.connect`` and script the mocked cursor's responses."""
        self.mock_connect_patcher = patch('psycopg2.connect')
        self.mock_connect = self.mock_connect_patcher.start()
        # addCleanup also runs when setUp fails part-way, unlike tearDown.
        self.addCleanup(self.mock_connect_patcher.stop)
        self.mock_cursor = Mock()
        self.mock_connect.return_value.cursor.return_value = self.mock_cursor

        self.columns = [('id',), ('feature1',), ('feature2',), ('target',)]
        self.class_counts = [('class1', 60), ('class2', 40)]
        self.feature_columns = ['feature1', 'feature2']
        self.executed_queries = []

        def execute_side_effect(query, *args, **kwargs):
            # Record every SQL string so tests can inspect what was sent.
            self.executed_queries.append(query)
            return None

        self.mock_cursor.execute.side_effect = execute_side_effect

        self.data_chunks = [
            [(1, 1.1, 2.1, 'class1')] * 10,
            [(1, 1.1, 2.1, 'class1')] * 10,
            [(2, 1.2, 2.2, 'class2')] * 10,
            [(2, 1.2, 2.2, 'class2')] * 10,
        ]
        self._call_count = 0

        def fetchmany_side_effect(batch_size):
            # One counter for the whole test: chunks are served once, then [].
            if self._call_count >= len(self.data_chunks):
                return []
            result = self.data_chunks[self._call_count]
            self._call_count += 1
            return result

        self.mock_cursor.fetchmany.side_effect = fetchmany_side_effect

        # First fetchall -> column list, second -> per-class counts.
        self.mock_cursor.fetchall.side_effect = [
            self.columns,
            self.class_counts,
        ]

        self.mock_cursor.description = [
            ('id', None), ('feature1', None), ('feature2', None), ('target', None)
        ]

        self.connection_params = {
            'dbname': 'test_db',
            'user': 'user',
            'password': 'pass',
            'host': 'localhost'
        }

    def test_invalid_train_val_size_raises_value_error(self):
        """ValueError with the documented message when train_size + val_size > 1."""
        with self.assertRaises(ValueError) as context:
            stratified_train_test_validation_split(
                self.connection_params,
                'test_table',
                'target',
                train_size=0.7,
                val_size=0.4
            )
        expected_msg = "Сумма train_size и val_size должна быть <= 1."
        self.assertEqual(str(context.exception), expected_msg)

    def test_sql_queries_built_correctly(self):
        """Metadata queries are exact; split queries target the right table/column."""
        stratified_train_test_validation_split(
            self.connection_params,
            'test_table',
            'target',
            id_column='id',
            train_size=0.6,
            val_size=0.2,
            seed=42
        )

        column_query = (
            "SELECT column_name FROM information_schema.columns WHERE table_name='test_table'"
        )
        count_query = (
            "SELECT target, COUNT(*) FROM test_table GROUP BY target"
        )

        self.assertIn(column_query, self.executed_queries)
        self.assertIn(count_query, self.executed_queries)

        split_queries = [
            query for query in self.executed_queries
            if query not in (column_query, count_query)
        ]
        self.assertTrue(split_queries)
        for query in split_queries:
            self.assertIn('test_table', query)
            self.assertIn('target', query)
            if 'BETWEEN' in query:
                self.assertIn("ABS(HASHTEXT", query)

    def test_pool_objects_returned_with_valid_data(self):
        """Pool objects are returned and train receives all mocked rows."""
        train_pool, val_pool, class_weights = stratified_train_test_validation_split(
            self.connection_params,
            'test_table',
            'target',
            id_column='id',
            train_size=0.6,
            val_size=0.2,
            seed=42
        )

        self.assertIsInstance(train_pool, Pool)
        self.assertIsInstance(val_pool, Pool)
        self.assertGreater(len(class_weights), 0)
        # Train is fetched first, so it consumes every mocked chunk.
        expected_rows = sum(len(chunk) for chunk in self.data_chunks)
        self.assertEqual(train_pool.num_row(), expected_rows)

    def test_class_weights_calculated_correctly(self):
        """Weights equal total / (n_classes * class_count) for every class."""
        _, _, class_weights = stratified_train_test_validation_split(
            self.connection_params,
            'test_table',
            'target',
            id_column='id',
            train_size=0.6,
            val_size=0.2,
            seed=42
        )

        total_samples = sum(count for _, count in self.class_counts)
        expected_weights = {
            cls: total_samples / (len(self.class_counts) * count)
            for cls, count in self.class_counts
        }

        # Check the key sets too, so a missing class cannot pass silently.
        self.assertEqual(set(class_weights), set(expected_weights))
        for cls, weight in expected_weights.items():
            self.assertAlmostEqual(class_weights[cls], weight, delta=1e-6)

    def test_test_data_saved_to_file_correctly(self):
        """The test CSV has feature, target and weight columns."""
        with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as tmpfile:
            filename = tmpfile.name

        try:
            stratified_train_test_validation_split(
                self.connection_params,
                'test_table',
                'target',
                id_column='id',
                train_size=0.6,
                val_size=0.2,
                seed=42,
                output_test_file=filename
            )

            df = pd.read_csv(filename)
            self.assertTrue(len(df.columns) > 0)
            self.assertIn('target', df.columns)
            self.assertIn('weight', df.columns)
            self.assertTrue(all(col in df.columns for col in self.feature_columns))
        finally:
            if os.path.exists(filename):
                os.remove(filename)


if __name__ == '__main__':
    # argv=[''] keeps unittest from parsing the host process's own argv
    # (e.g. a Jupyter kernel's flags); exit=False avoids calling sys.exit.
    unittest.main(argv=[''], exit=False)

.....
----------------------------------------------------------------------
Ran 5 tests in 0.081s

OK
