<a href="https://colab.research.google.com/github/Yanina-Kutovaya/GNN/blob/main/notebooks/Mock_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Проверка работы Mock-тестов для BitcoinGraphDataset

In [1]:
# Install the packages this notebook imports (IPython shell escape; notebook-only syntax).
!pip install -q torch torch-geometric pandas psycopg2-binary

In [2]:
import logging
import unittest
from unittest.mock import MagicMock, patch

import pandas as pd
import psycopg2
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# === BitcoinGraphDataset class ===
class BitcoinGraphDataset(torch.utils.data.Dataset):
    """Map-style dataset serving a Bitcoin transaction graph from PostgreSQL in node batches.

    Item ``i`` is the ``i``-th page of ``batch_size`` rows of ``node_attributes``
    (ordered by ``alias``), plus the ``edge_attributes`` rows whose endpoints
    ``a`` and ``b`` both belong to that page, packed as a
    ``torch_geometric.data.Data``:

    * ``x``          -- float ``[degree, total_received]`` per node, shape ``(N, 2)``;
    * ``edge_index`` -- long batch-local row indices into ``x``, shape ``(2, E)``;
    * ``edge_attr``  -- float ``total_sent`` per edge, shape ``(E, 1)``;
    * ``y``          -- long node ``label``, shape ``(N,)``.

    Database failures are deliberately non-fatal. They are logged, and the
    failing call yields an empty graph (or a length of 0 when the node count
    cannot be read in the constructor).

    Note: indices at or beyond ``len(self)`` do not raise ``IndexError``, so a
    plain ``for batch in dataset`` never terminates. Iterate over
    ``range(len(dataset))`` or use a DataLoader instead.
    NOTE(review): one psycopg2 connection is opened in the constructor and
    shared by every ``__getitem__`` call. Confirm this before using a
    DataLoader with ``num_workers > 0``.
    """

    # Node feature columns, in the order they appear in ``x``.
    _NODE_FEATURES = ("degree", "total_received")

    _log = logging.getLogger(__name__)

    def __init__(self, db_config, batch_size=32):
        """Connect to the database and count the nodes available for batching.

        Args:
            db_config: Keyword arguments passed to ``psycopg2.connect``.
            batch_size: Number of nodes per dataset item; must be >= 1.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 (it would otherwise
                fail later in ``__len__`` or in the SQL ``LIMIT``).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.db_config = db_config
        self.batch_size = batch_size
        self.conn = None
        self.cursor = None
        self.total_nodes = 0
        try:
            self.conn = psycopg2.connect(**db_config)
            self.cursor = self.conn.cursor()
            self.cursor.execute("SELECT COUNT(*) FROM node_attributes")
            result = self.cursor.fetchone()
            self.total_nodes = result[0] if result else 0
        except Exception as err:
            # Best effort by design: an unreachable database gives an empty
            # dataset rather than a crash, but the cause is no longer hidden.
            self._log.warning("Could not count rows in node_attributes: %s", err)
            self._rollback()

    def close(self):
        """Close the cursor and the connection. Calling it again is a no-op.

        After closing, every item is an empty graph.
        """
        for handle in (self.cursor, self.conn):
            if handle is None:
                continue
            try:
                handle.close()
            except Exception as err:
                self._log.warning("Error while closing database handle: %s", err)
        self.cursor = None
        self.conn = None

    def __enter__(self):
        """Return ``self`` so the dataset can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the database handles; exceptions are not suppressed."""
        self.close()

    def __len__(self):
        """Return the number of batches, i.e. ``ceil(total_nodes / batch_size)``."""
        return (self.total_nodes + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as a ``Data`` graph.

        Negative indices count from the end, as for Python sequences. Any
        database failure, a closed or missing connection, or an index outside
        the table yields an empty graph.
        """
        if self.conn is None or self.cursor is None:
            return self._empty_data()
        if idx < 0:
            idx += len(self)
            if idx < 0:
                return self._empty_data()

        nodes_df = self._read_nodes(idx * self.batch_size)
        if nodes_df is None or nodes_df.empty:
            return self._empty_data()

        aliases = nodes_df['alias'].tolist()
        edges_df = self._read_edges(aliases)
        if edges_df is None:
            return self._empty_data()

        edge_index, edge_attr = self._build_edges(edges_df, aliases)
        # astype(float) also accepts object columns (e.g. Decimal), which torch.tensor would reject.
        x = torch.tensor(
            nodes_df[list(self._NODE_FEATURES)].fillna(0).astype(float).to_numpy(),
            dtype=torch.float,
        )
        y = torch.tensor(nodes_df['label'].fillna(0).to_numpy(), dtype=torch.long)
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

    def _read_nodes(self, offset):
        """Fetch one page of ``node_attributes``; return None if the query fails."""
        # Without ORDER BY, PostgreSQL may return LIMIT/OFFSET pages in any
        # order, so batches could overlap or skip nodes. This assumes alias is
        # unique, which the alias -> row mapping in _build_edges relies on too.
        # Only int() values are interpolated, so the SQL text cannot be injected into.
        query = f"""
        SELECT alias, label, degree, total_received
        FROM node_attributes
        ORDER BY alias
        LIMIT {int(self.batch_size)} OFFSET {int(offset)}
        """
        return self._read_sql(query)

    def _read_edges(self, aliases):
        """Fetch edges with both endpoints in ``aliases``; return None if the query fails."""
        # The alias list is bound as a parameter; psycopg2 adapts a Python
        # list to an SQL ARRAY. Requiring both endpoints (AND) fetches only
        # the rows that _build_edges keeps.
        query = """
        SELECT a, b, total_sent
        FROM edge_attributes
        WHERE a = ANY(%(aliases)s) AND b = ANY(%(aliases)s)
        """
        return self._read_sql(query, params={"aliases": aliases})

    def _read_sql(self, query, params=None):
        """Run ``query`` via pandas; on failure, log, roll back and return None."""
        try:
            return pd.read_sql(query, self.conn, params=params)
        except Exception as err:
            self._log.warning("Query failed, returning an empty graph: %s", err)
            self._rollback()
            return None

    def _rollback(self):
        """Roll back the current transaction, if there is a connection.

        A failed statement leaves a psycopg2 transaction aborted, and PostgreSQL
        then rejects every further statement on that connection until a rollback.
        """
        if self.conn is None:
            return
        try:
            self.conn.rollback()
        except Exception as err:
            self._log.warning("Rollback failed: %s", err)

    @staticmethod
    def _build_edges(edges_df, aliases):
        """Convert edge rows to ``(edge_index, edge_attr)`` over batch-local indices.

        Rows whose ``a`` or ``b`` is not in ``aliases`` are dropped. The
        position of an alias in ``aliases`` becomes its node index.
        """
        if edges_df.empty:
            return torch.zeros((2, 0), dtype=torch.long), torch.zeros((0, 1), dtype=torch.float)
        alias_to_idx = {alias: i for i, alias in enumerate(aliases)}
        inside = edges_df['a'].isin(aliases) & edges_df['b'].isin(aliases)
        kept = edges_df[inside]
        src = [alias_to_idx[a] for a in kept['a'].tolist()]
        dst = [alias_to_idx[b] for b in kept['b'].tolist()]
        # [[], []] becomes shape (2, 0), so zero kept edges still give a valid edge_index.
        edge_index = torch.tensor([src, dst], dtype=torch.long)
        edge_attr = torch.tensor(
            kept['total_sent'].astype(float).tolist(), dtype=torch.float
        ).view(-1, 1)
        return edge_index, edge_attr

    @classmethod
    def _empty_data(cls):
        """Return a graph with no nodes or edges, keeping each field's usual rank."""
        return Data(
            x=torch.zeros((0, len(cls._NODE_FEATURES)), dtype=torch.float),
            edge_index=torch.zeros((2, 0), dtype=torch.long),
            edge_attr=torch.zeros((0, 1), dtype=torch.float),
            y=torch.zeros((0,), dtype=torch.long),
        )


# === Test class ===
class TestBitcoinGraphDataset(unittest.TestCase):
    """Unit tests for BitcoinGraphDataset with the database fully mocked.

    ``psycopg2.connect`` is patched to return a MagicMock connection whose
    cursor reports a configurable node count. ``pandas.read_sql`` is patched to
    serve in-memory DataFrames, chosen by matching on the SQL text.
    """

    def setUp(self):
        # Database settings; never used for a real connection because every
        # test that builds a dataset patches psycopg2.connect.
        self.db_config = {
            'dbname': 'test_db',
            'user': 'test_user',
            'password': 'test_pass',
            'host': 'localhost',
            'port': 5432,
        }

        # Small dataset (2 nodes)
        self.nodes_data_small = pd.DataFrame({
            'alias': [1, 2],
            'label': [0, 1],
            'degree': [2, 1],
            'total_received': [100.0, 200.0],
        })

        # Large dataset (65 nodes): with batch_size=32 that is batches of 32, 32 and 1
        self.nodes_data_large = pd.DataFrame({
            'alias': list(range(1, 66)),
            'label': [i % 2 for i in range(65)],
            'degree': [2] * 65,
            'total_received': [100.0 + i for i in range(65)],
        })

        # Edges 1 -> 2 and 2 -> 1
        self.edges_data = pd.DataFrame({
            'a': [1, 2],
            'b': [2, 1],
            'total_sent': [50.0, 30.0],
        })

    def _setup_mock(self, mock_connect, total_nodes=2, nodes_data=None):
        """Make the patched psycopg2.connect report ``total_nodes`` node rows.

        Returns ``(connection_mock, cursor_mock, nodes_df)``. ``nodes_df`` is
        ``nodes_data`` if given, otherwise the 2-node dataset.
        """
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_cursor.fetchone.return_value = [total_nodes]
        mock_conn.cursor.return_value = mock_cursor
        mock_connect.return_value = mock_conn
        nodes_data = nodes_data if nodes_data is not None else self.nodes_data_small
        return mock_conn, mock_cursor, nodes_data

    def _route_read_sql(self, mock_read_sql, nodes, edges):
        """Serve node and edge DataFrames from the patched pandas.read_sql.

        ``nodes`` is either a DataFrame or a callable that receives the SQL
        text (used to emulate LIMIT/OFFSET paging). Returns a list that records
        every query, so SQL assertions can run in the test body. An
        AssertionError raised inside the mock would be swallowed by the
        dataset's ``except Exception``, and the test would pass vacuously.
        """
        queries = []

        def fake_read_sql(query, *args, **kwargs):
            queries.append(query)
            if 'SELECT alias' in query:
                return nodes(query) if callable(nodes) else nodes
            if 'SELECT a, b' in query:
                return edges
            return pd.DataFrame()

        mock_read_sql.side_effect = fake_read_sql
        return queries

    @patch('psycopg2.connect')
    @patch('pandas.read_sql')
    def test_getitem_returns_correct_data(self, mock_read_sql, mock_connect):
        """A full batch yields the expected features, edges, edge weights and labels."""
        _, _, nodes_data = self._setup_mock(mock_connect, 2)
        self._route_read_sql(mock_read_sql, nodes_data, self.edges_data)

        dataset = BitcoinGraphDataset(self.db_config, batch_size=2)
        batch = dataset[0]

        self.assertIsInstance(batch, Data)
        self.assertEqual(batch.x.shape, (2, 2))
        self.assertEqual(batch.edge_index.shape, (2, 2))
        self.assertEqual(batch.edge_attr.shape, (2, 1))
        self.assertEqual(batch.y.shape, (2,))
        # Aliases (1, 2) are re-indexed to batch-local rows (0, 1).
        self.assertEqual(batch.x.tolist(), [[2.0, 100.0], [1.0, 200.0]])
        self.assertEqual(batch.edge_index.tolist(), [[0, 1], [1, 0]])
        self.assertEqual(batch.edge_attr.tolist(), [[50.0], [30.0]])
        self.assertEqual(batch.y.tolist(), [0, 1])

    @patch('psycopg2.connect')
    @patch('pandas.read_sql')
    def test_getitem_returns_empty_graph_when_no_edges(self, mock_read_sql, mock_connect):
        """Nodes without edges give an edge-less graph that keeps the node data."""
        _, _, nodes_data = self._setup_mock(mock_connect, 2)
        self._route_read_sql(
            mock_read_sql, nodes_data, pd.DataFrame(columns=['a', 'b', 'total_sent'])
        )

        dataset = BitcoinGraphDataset(self.db_config, batch_size=2)
        batch = dataset[0]

        self.assertIsInstance(batch, Data)
        self.assertEqual(batch.x.shape, (2, 2))
        self.assertEqual(batch.edge_index.shape, (2, 0))
        self.assertEqual(batch.edge_attr.shape, (0, 1))
        self.assertEqual(batch.y.shape, (2,))

    @patch('psycopg2.connect')
    def test_len_returns_correct_value(self, mock_connect):
        """len() is the number of batches, rounded up."""
        cases = [(65, 32, 3), (64, 32, 2), (1, 32, 1), (0, 32, 0)]
        for total, batch_size, expected in cases:
            with self.subTest(total=total, batch_size=batch_size):
                self._setup_mock(mock_connect, total)
                dataset = BitcoinGraphDataset(self.db_config, batch_size=batch_size)
                self.assertEqual(len(dataset), expected)

    @patch('psycopg2.connect')
    @patch('pandas.read_sql')
    def test_model_compatibility(self, mock_read_sql, mock_connect):
        """A batch can be fed directly to a GCNConv layer."""
        _, _, nodes_data = self._setup_mock(mock_connect, 2)
        self._route_read_sql(mock_read_sql, nodes_data, self.edges_data)

        dataset = BitcoinGraphDataset(self.db_config, batch_size=2)
        batch = dataset[0]

        class DummyGCN(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = GCNConv(2, 2)

            def forward(self, x, edge_index):
                return self.conv(x, edge_index)

        model = DummyGCN()
        with torch.no_grad():
            out = model(batch.x, batch.edge_index)
        self.assertEqual(out.shape, (2, 2))

    @patch('psycopg2.connect')
    @patch('pandas.read_sql')
    def test_minimal_dataset(self, mock_read_sql, mock_connect):
        """A single node with a self-loop gives one node and one edge (0 -> 0)."""
        minimal_nodes = pd.DataFrame({
            'alias': [1],
            'label': [0],
            'degree': [1],
            'total_received': [100.0],
        })
        minimal_edges = pd.DataFrame({
            'a': [1],
            'b': [1],
            'total_sent': [50.0],
        })
        self._setup_mock(mock_connect, 1)
        self._route_read_sql(mock_read_sql, minimal_nodes, minimal_edges)

        dataset = BitcoinGraphDataset(self.db_config, batch_size=1)
        batch = dataset[0]

        self.assertEqual(batch.x.shape, (1, 2))
        self.assertEqual(batch.edge_index.shape, (2, 1))
        self.assertEqual(batch.edge_attr.shape, (1, 1))
        self.assertEqual(batch.edge_index.tolist(), [[0], [0]])

    @patch('psycopg2.connect')
    @patch('pandas.read_sql')
    def test_error_handling_on_sql_failure(self, mock_read_sql, mock_connect):
        """Failing COUNT and failing read_sql both end in an empty graph, not a crash."""
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_cursor.fetchone.side_effect = Exception("Database connection failed")
        mock_conn.cursor.return_value = mock_cursor
        mock_connect.return_value = mock_conn
        mock_read_sql.side_effect = Exception("Invalid SQL query")

        dataset = BitcoinGraphDataset(self.db_config, batch_size=2)
        batch = dataset[0]

        self.assertIsInstance(batch, Data)
        self.assertEqual(batch.x.numel(), 0)
        self.assertEqual(batch.edge_index.numel(), 0)
        self.assertEqual(batch.edge_attr.numel(), 0)
        self.assertEqual(batch.y.numel(), 0)

    @patch('psycopg2.connect')
    @patch('pandas.read_sql')
    def test_sql_query_format(self, mock_read_sql, mock_connect):
        """Node queries page with LIMIT batch_size OFFSET idx * batch_size."""
        _, _, nodes_data = self._setup_mock(mock_connect, 2)
        queries = self._route_read_sql(mock_read_sql, nodes_data, self.edges_data)

        dataset = BitcoinGraphDataset(self.db_config, batch_size=2)
        dataset[0]
        dataset[1]

        # Asserted here, not inside the mock, so a failure is actually reported.
        node_queries = [q for q in queries if 'SELECT alias' in q]
        self.assertEqual(len(node_queries), 2)
        self.assertRegex(node_queries[0], r"LIMIT\s+2\s+OFFSET\s+0")
        self.assertRegex(node_queries[1], r"LIMIT\s+2\s+OFFSET\s+2")
        self.assertTrue(any('SELECT a, b' in q for q in queries))

    @patch('psycopg2.connect')
    @patch('pandas.read_sql')
    def test_batch_consistency(self, mock_read_sql, mock_connect):
        """Batches have the expected sizes and together cover every node exactly once."""
        _, _, nodes_data = self._setup_mock(mock_connect, 65, self.nodes_data_large)

        def page(query):
            limit = int(query.split("LIMIT")[1].split()[0])
            offset = int(query.split("OFFSET")[1].strip())
            return nodes_data.iloc[offset:offset + limit].copy()

        self._route_read_sql(mock_read_sql, page, self.edges_data)

        dataset = BitcoinGraphDataset(self.db_config, batch_size=32)
        self.assertEqual(len(dataset), 3)
        received = []
        for i in range(len(dataset)):
            batch = dataset[i]
            expected_size = min(32, 65 - i * 32)
            self.assertEqual(batch.x.shape[0], expected_size)
            received.extend(batch.x[:, 1].tolist())
        # total_received is unique per node, so this checks no overlap and no gaps.
        self.assertEqual(received, [100.0 + i for i in range(65)])


In [3]:
# === Run the tests ===
if __name__ == '__main__':
    # A notebook kernel's sys.argv presumably holds the kernel's own launch
    # arguments, which unittest's command-line parser would try to read. Hand
    # it a placeholder argv instead. exit=False stops unittest from calling
    # sys.exit() once the run finishes.
    placeholder_argv = ['']
    unittest.main(argv=placeholder_argv, exit=False)

........
----------------------------------------------------------------------
Ran 8 tests in 0.106s

OK
