In [10]:
import numpy as np
import pandas as pd
import ctgan
from ctgan import CTGAN

In [11]:
import dataclasses


@dataclasses.dataclass
class Card:
    """Everything needed to evaluate and report on one trained CTGAN model.

    The class defines its own ``__init__``; ``dataclasses.dataclass`` does not
    replace an ``__init__`` that is already present on the class, so the
    decorator only contributes the generated ``__repr__``/``__eq__``.

    Attributes:
        name: Display name of the card (used as the key in the cards dict).
        model: CTGAN model loaded from ``model_path``.
        target: Name of the target column in the dataset.
        schedule_path: Path to an image file embedded into the HTML report.
        real_data: Real dataset read from ``dataset_path``.
        synt_data: Synthetic sample drawn from ``model`` with as many rows as
            ``real_data``.
        jensen_shannon_divergence: Per-column divergence table; empty until
            it is filled in by the evaluation step.
        real_score: Accuracy of a classifier trained on real data (0.0 until
            evaluated).
        synt_score: Accuracy of a classifier trained on synthetic data (0.0
            until evaluated).
    """

    name: str
    model: CTGAN
    target: str
    schedule_path: str
    real_data: pd.DataFrame
    synt_data: pd.DataFrame
    jensen_shannon_divergence: pd.DataFrame
    real_score: float
    synt_score: float

    def __init__(self, name: str, target: str, dataset_path: str, model_path: str, schedule_path: str):
        """Load the model and the dataset, then draw the synthetic sample.

        Args:
            name: Display name of the card.
            target: Name of the target column.
            dataset_path: Path to the CSV file with the real data.
            model_path: Path to the saved CTGAN model.
            schedule_path: Path to the image shown on the card.
        """
        self.name = name
        self.target = target
        self.model = ctgan.CTGAN.load(model_path)
        self.real_data = pd.read_csv(dataset_path)
        # Sample exactly as many synthetic rows as there are real rows.
        self.synt_data = self.model.sample(len(self.real_data))
        self.schedule_path = schedule_path
        self.jensen_shannon_divergence = pd.DataFrame()
        self.real_score = 0.0
        # Was misspelled as `sint_score`, which left the declared `synt_score`
        # field unset and made repr()/attribute access fail before evaluation.
        self.synt_score = 0.0


In [12]:
from pathlib import Path


def load_cards_from_params(params_dir: str) -> dict[str, Card]:
    """Build a ``Card`` for every parameter file in ``params_dir``.

    Each parameter file is plain UTF-8 text; its non-blank lines are read
    positionally as: name, target, dataset path, model path, schedule path.
    Missing trailing lines default to an empty string. Files with no
    non-blank lines, and entries that are not regular files, are skipped.

    Files are processed in sorted path order. ``Path.iterdir()`` yields
    entries in arbitrary order, which would otherwise make the order of the
    returned dict (and which file wins when two share a name) vary between
    runs and platforms.

    Args:
        params_dir: Directory containing the parameter files.

    Returns:
        Mapping of card name to ``Card``; empty when ``params_dir`` does not
        exist or is not a directory.
    """
    params_path = Path(params_dir)
    cards: dict[str, Card] = {}
    # is_dir() is already False for a path that does not exist.
    if not params_path.is_dir():
        return cards

    field_count = 5
    for file in sorted(params_path.iterdir()):
        if not file.is_file():
            continue

        lines = [ln.strip() for ln in file.read_text(encoding='utf-8').splitlines() if ln.strip()]
        if not lines:
            continue

        # Pad to the five expected fields so short files fall back to "".
        fields = lines[:field_count]
        fields += [""] * (field_count - len(fields))
        name, target, dataset_path, model_path, schedule_path = fields

        cards[name] = Card(name=name, target=target, dataset_path=dataset_path,
                           model_path=model_path, schedule_path=schedule_path)

    return cards

In [13]:
# Build one Card per parameter file found under ./models/params/, keyed by card name.
cards = load_cards_from_params('./models/params/')

In [14]:
import warnings
from sklearn.exceptions import ConvergenceWarning
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Suppress these warning categories to keep the cell output clean
warnings.filterwarnings("ignore", category=ConvergenceWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


def evaluate_card(card: Card) -> None:
    """Score classifiers trained on real and on synthetic data against real test data.

    Both datasets are split 50/50 (stratified by the target, fixed seed). One
    scaler + logistic-regression pipeline is fitted on the real training half
    and another on the synthetic training half; both are scored by accuracy on
    the same held-out real half. The results are written to
    ``card.real_score`` and ``card.synt_score``.
    """
    target_column = card.target

    def _features_and_labels(frame):
        # Separate the target column from the remaining feature columns.
        return frame.drop(columns=[target_column]), frame[target_column]

    def _stratified_halves(features, labels):
        # Stratified split so both halves keep the class distribution.
        return train_test_split(
            features, labels, test_size=0.5, random_state=42, stratify=labels
        )

    def _new_classifier():
        # StandardScaler + LogisticRegression, without the deprecated multi_class argument.
        return make_pipeline(
            StandardScaler(),
            LogisticRegression(solver='saga', max_iter=20000, random_state=42)
        )

    real_features, real_labels = _features_and_labels(card.real_data)
    synt_features, synt_labels = _features_and_labels(card.synt_data)

    real_fit_X, real_holdout_X, real_fit_y, real_holdout_y = _stratified_halves(real_features, real_labels)
    # Only the training half of the synthetic data is used; its test half is discarded.
    synt_fit_X, _, synt_fit_y, _ = _stratified_halves(synt_features, synt_labels)

    model_on_real = _new_classifier()
    model_on_real.fit(real_fit_X, real_fit_y)
    accuracy_real = accuracy_score(real_holdout_y, model_on_real.predict(real_holdout_X))

    model_on_synt = _new_classifier()
    model_on_synt.fit(synt_fit_X, synt_fit_y)
    # The synthetic-trained model is also judged on the held-out REAL rows.
    accuracy_synt = accuracy_score(real_holdout_y, model_on_synt.predict(real_holdout_X))

    card.real_score = accuracy_real
    card.synt_score = accuracy_synt


In [15]:
import numpy as np
import pandas as pd


def calculate_distributions(p_series: pd.Series, q_series: pd.Series):
    """
    Build empirical probability distributions for two pandas Series.

    Both results are indexed by the union of the unique values of the two
    inputs; a value that is absent from one Series gets probability 0 there.
    Returns the pair (p, q).
    """
    # Common support: every distinct value seen in either column.
    support = pd.Index(p_series.unique()).union(q_series.unique())

    def _probabilities(series: pd.Series) -> pd.Series:
        # Relative frequencies, laid out on the common support.
        frequencies = series.value_counts(normalize=True)
        return frequencies.reindex(support, fill_value=0)

    return _probabilities(p_series), _probabilities(q_series)


def _kl_divergence_bits(dist: pd.Series, reference: pd.Series) -> float:
    """Return D_KL(dist || reference) in bits; both Series must share one index."""
    # Terms with zero probability contribute nothing (0 * log 0 is taken as 0).
    support = dist > 0
    mass = dist[support]
    return float(np.sum(mass * np.log2(mass / reference[support])))


def calculate_metrics(p_df: pd.DataFrame, q_df: pd.DataFrame, skip_col: list[str]) -> pd.DataFrame:
    """
    Compute the Jensen-Shannon divergence (base 2) for every shared column.

    For each column of ``p_df`` that also exists in ``q_df`` and is not listed
    in ``skip_col``, the two columns are turned into distributions over their
    common set of unique values and compared as

        JSD(P || Q) = 0.5 * D_KL(P || M) + 0.5 * D_KL(Q || M),  M = 0.5 * (P + Q)

    No smoothing constant is added to M: M is strictly positive wherever P or
    Q is, so the logarithms are always defined, and identical distributions
    give exactly 0 instead of a slightly negative value.

    Args:
        p_df: Reference data (for example, the real dataset).
        q_df: Data to compare against the reference (for example, synthetic).
        skip_col: Column names to leave out of the comparison.

    Returns:
        DataFrame with the columns ``column`` and ``jensen_shannon_divergence``,
        one row per compared column. When nothing is compared the frame is
        empty but still has both columns.
    """
    result_columns = ['column', 'jensen_shannon_divergence']
    records = []

    for col in p_df.columns:
        if col not in q_df.columns or col in skip_col:
            continue

        p, q = calculate_distributions(p_df[col], q_df[col])

        # Mixture distribution used as the common reference for both KL terms.
        m = 0.5 * (p + q)
        jsd = 0.5 * _kl_divergence_bits(p, m) + 0.5 * _kl_divergence_bits(q, m)

        records.append({
            'column': col,
            # Guard against a tiny negative value caused by float rounding.
            'jensen_shannon_divergence': max(jsd, 0.0),
        })

    # Passing `columns` keeps the schema stable even when `records` is empty.
    return pd.DataFrame(records, columns=result_columns)

# Example usage with Card objects (skip_col is a required argument;
# at minimum, exclude the target column):
# for card in cards.values():
#     # Compute the metrics
#     metrics_df = calculate_metrics(card.real_data, card.synt_data, skip_col=[card.target])
#     # Store the result (for example, in a new attribute)
#     card.metrics = metrics_df
#     print(f"Metrics for {card.name}:")
#     print(card.metrics)
#     print("-" * 30)


In [16]:
for card in cards.values():
    # Fill in real_score / synt_score on the card.
    evaluate_card(card)
    # Leave the target and identifier-like columns out of the distribution comparison.
    ignored_columns = [card.target, 'Id', 'id', 'ID', 'identifier']
    card.jensen_shannon_divergence = calculate_metrics(
        card.real_data,
        card.synt_data,
        skip_col=ignored_columns,
    )

In [18]:

# === HTML report generator (autonomous, single-file) ===
import base64, io, json, os, statistics
from pathlib import Path
from datetime import datetime
from html import escape


def _embed_image_base64(path: str) -> str:
    """
    Read an image by path and return a <img> tag with data URI.
    If file is missing or unreadable, returns an empty string.
    """
    if not path:
        return ""
    try:
        p = Path(path)
        if not p.exists():
            # try relative to notebook root
            p = Path.cwd() / path
        with open(p, "rb") as f:
            b64 = base64.b64encode(f.read()).decode("ascii")
        # Minimal type inference by extension
        ext = p.suffix.lower().lstrip(".")
        mime = {
            "png": "image/png", "jpg": "image/jpeg", "jpeg": "image/jpeg",
            "gif": "image/gif", "svg": "image/svg+xml", "webp": "image/webp"
        }.get(ext, "image/png")
        return f'<img class="schedule-img" src="data:{mime};base64,{b64}" alt="schedule">'
    except Exception:
        return ""


def _fmt(v, digits=4):
    try:
        if v is None: return "—"
        return f"{float(v):.{digits}f}"
    except Exception:
        return str(v)


def _safe_get_synt_score(card):
    # Falls back to a mistyped attribute if present to avoid breaking user's class.
    val = getattr(card, "synt_score", None)
    if val is None:
        val = getattr(card, "sint_score", None)
    return val


def _card_anchor(index: int, name) -> str:
    """Build an HTML id for a card that is safe in attributes and inline JS.

    Every character other than word characters and '-' is replaced by '_',
    so the id can never contain quotes, whitespace or markup. The positional
    index keeps ids unique even when two names collapse to the same slug.
    """
    import re  # local import: this cell's import line does not provide `re`

    slug = re.sub(r"[^\w-]+", "_", str(name)).strip("_")
    return f"card-{index}-{slug}" if slug else f"card-{index}"


def create_cards_html(cards: dict, output_path: str = "cards_report.html", top_k: int = 30) -> str:
    """
    Write a fully self-contained HTML report with:
      • one card per Card object in `cards`
      • a summary table for side-by-side comparison
    Styles, scripts and images are inlined (images as base64), so the file
    can be opened on any machine without additional assets.

    Args:
      cards: dict[str, Card] — mapping of card name to Card object
      output_path: where to save the HTML file
      top_k: how many JSD rows to show per card (highest first);
             a falsy value (None or 0) shows all rows
    Returns the path of the file that was written.
    """
    # One stable, sanitised anchor per card, shared by the summary table links,
    # the card ids and the inline onclick handlers. Escaping alone is not
    # enough here: an escaped apostrophe is decoded by the browser before the
    # onclick JavaScript runs and would terminate the quoted table id.
    anchors = {name: _card_anchor(idx, name) for idx, name in enumerate(cards)}

    # Collect the aggregates for the comparison table
    comparison = []
    for name, card in cards.items():
        n_rows = len(card.real_data) if hasattr(card, "real_data") else 0
        # features: all columns except the target
        try:
            n_features = max(0, card.real_data.shape[1] - 1)
        except Exception:
            n_features = 0
        real_score = getattr(card, "real_score", None)
        synt_score = _safe_get_synt_score(card)
        gap = (real_score - synt_score) if (real_score is not None and synt_score is not None) else None

        # mean JSD over all compared columns; only attempted when the JSD
        # table has the same type as real_data
        mean_jsd = None
        if hasattr(card, "jensen_shannon_divergence") and isinstance(card.jensen_shannon_divergence,
                                                                     type(getattr(card, "real_data", None))):
            try:
                jsd_col = card.jensen_shannon_divergence["jensen_shannon_divergence"]
                mean_jsd = float(jsd_col.mean())
            except Exception:
                mean_jsd = None

        comparison.append({
            "name": name,
            "anchor": anchors[name],
            "target": getattr(card, "target", ""),
            "rows": n_rows,
            "features": n_features,
            "real_score": real_score,
            "synt_score": synt_score,
            "gap": gap,
            "mean_jsd": mean_jsd,
        })

    # HTML template: inline stylesheet
    style = """
    <style>
      :root {
        --bg: #0e1116;
        --panel: #161b22;
        --text: #e6edf3;
        --muted: #9aa6b2;
        --accent: #7aa2f7;
        --ok: #3fb950;
        --warn: #e3b341;
        --bad: #f85149;
        --border: #30363d;
      }
      * { box-sizing: border-box; }
      body {
        margin: 0; padding: 24px 24px 120px;
        font-family: -apple-system,BlinkMacSystemFont,Segoe UI,Roboto,Helvetica,Arial,Ubuntu, Cantarell, "Noto Sans", "Apple Color Emoji","Segoe UI Emoji";
        background: var(--bg); color: var(--text);
      }
      h1 { margin: 0 0 16px; font-size: 28px; }
      .subtitle { color: var(--muted); margin-bottom: 32px; }
      .grid {
        display: grid;
        grid-template-columns: repeat(auto-fill, minmax(340px, 1fr));
        gap: 16px;
      }
      .card {
        background: var(--panel);
        border: 1px solid var(--border);
        border-radius: 12px;
        padding: 16px;
      }
      .card h2 { margin: 0 0 8px; font-size: 20px; }
      .meta { color: var(--muted); font-size: 12px; margin-bottom: 8px; }
      .badges { display: flex; gap: 8px; flex-wrap: wrap; margin: 8px 0 12px; }
      .badge {
        padding: 4px 8px; border-radius: 999px; font-size: 12px; border: 1px solid var(--border);
        background: #0b1220; color: var(--text);
      }
      .badge.ok { border-color: rgba(63,185,80,.4); background: rgba(63,185,80,.08); }
      .badge.warn { border-color: rgba(227,179,65,.4); background: rgba(227,179,65,.08); }
      .badge.bad { border-color: rgba(248,81,73,.4); background: rgba(248,81,73,.08); }
      .schedule-img { display:block; max-width:100%; height:auto; border-radius:8px; border:1px solid var(--border); margin: 6px 0 10px; }
      details { margin-top: 8px; }
      details > summary {
        cursor: pointer; list-style: none; color: var(--accent);
        margin: 10px 0; user-select: none;
      }
      table { width: 100%; border-collapse: collapse; }
      thead th {
        text-align:left; font-weight:600; color: var(--muted);
        border-bottom:1px solid var(--border); padding:8px; cursor: pointer;
      }
      tbody td { padding: 8px; border-bottom: 1px dashed var(--border); font-size: 14px; }
      .cmp-table { margin: 6px 0 18px; background: var(--panel); border-radius: 12px; border:1px solid var(--border); overflow: hidden; }
      .footer { position: fixed; left:0; right:0; bottom:0; padding:10px 16px; background: linear-gradient(180deg, rgba(14,17,22,.0), rgba(14,17,22,.75) 20%); }
      .hint { color: var(--muted); font-size: 12px; }
      a { color: var(--accent); text-decoration: none; }
      .right { text-align: right; }
    </style>
    """
    # Minimal client-side table sorting
    script = """
    <script>
    // Very small table sorter
    function sortTable(tblId, colIdx, numeric) {
      const tbl = document.getElementById(tblId);
      const tbody = tbl.tBodies[0];
      const rows = Array.from(tbody.querySelectorAll("tr"));
      const asc = tbl.getAttribute("data-sort-dir") !== "asc";
      rows.sort((a, b) => {
        let av = a.children[colIdx].innerText;
        let bv = b.children[colIdx].innerText;
        if (numeric) {
          av = parseFloat(av.replace(',', '.'));
          bv = parseFloat(bv.replace(',', '.'));
          if (isNaN(av)) av = -Infinity;
          if (isNaN(bv)) bv = -Infinity;
        }
        if (av < bv) return asc ? -1 : 1;
        if (av > bv) return asc ? 1 : -1;
        return 0;
      });
      tbody.innerHTML = "";
      rows.forEach(r => tbody.appendChild(r));
      tbl.setAttribute("data-sort-dir", asc ? "asc" : "desc");
    }
    </script>
    """

    # Build the comparison table
    cmp_rows = []
    for row in comparison:
        cmp_rows.append(f"""
          <tr>
            <td><a href="#{row['anchor']}">{escape(str(row['name']))}</a></td>
            <td>{escape(str(row['target']))}</td>
            <td class="right">{row['rows']}</td>
            <td class="right">{row['features']}</td>
            <td class="right">{_fmt(row['real_score'])}</td>
            <td class="right">{_fmt(row['synt_score'])}</td>
            <td class="right">{_fmt(row['gap'])}</td>
            <td class="right">{_fmt(row['mean_jsd'])}</td>
          </tr>
        """)

    cmp_table = f"""
    <div class="cmp-table">
      <table id="cmp" data-sort-dir="asc">
        <thead>
          <tr>
            <th onclick="sortTable('cmp', 0, false)">Модель</th>
            <th onclick="sortTable('cmp', 1, false)">Целевая</th>
            <th class="right" onclick="sortTable('cmp', 2, true)">Строк</th>
            <th class="right" onclick="sortTable('cmp', 3, true)">Фич</th>
            <th class="right" onclick="sortTable('cmp', 4, true)">Acc реальных</th>
            <th class="right" onclick="sortTable('cmp', 5, true)">Acc синтет.</th>
            <th class="right" onclick="sortTable('cmp', 6, true)">Gap</th>
            <th class="right" onclick="sortTable('cmp', 7, true)">Mean JSD</th>
          </tr>
        </thead>
        <tbody>
          {''.join(cmp_rows)}
        </tbody>
      </table>
    </div>
    """

    # A falsy top_k means "no limit"; say so instead of printing "top None".
    top_label = f"top {top_k}" if top_k else "все"

    # One card per Card object
    cards_html = []
    for name, card in cards.items():
        anchor = anchors[name]
        real_score = getattr(card, "real_score", None)
        synt_score = _safe_get_synt_score(card)
        jsd_df = getattr(card, "jensen_shannon_divergence", None)

        # badges
        badges = []
        if real_score is not None:
            badges.append(f'<span class="badge ok">Acc(real): {_fmt(real_score)}</span>')
        if synt_score is not None:
            # "ok" when the synthetic score reaches at least 90% of the real one
            synt_class = "ok" if real_score and synt_score and synt_score >= real_score * 0.9 else "warn"
            badges.append(f'<span class="badge {synt_class}">Acc(synt): {_fmt(synt_score)}</span>')
        # row / column counts (best effort: skipped when real_data is unusable)
        try:
            badges.append(f'<span class="badge">Rows: {len(card.real_data)}</span>')
            badges.append(f'<span class="badge">Cols: {card.real_data.shape[1]}</span>')
        except Exception:
            pass

        # image referenced by schedule_path, inlined as base64
        img_html = _embed_image_base64(getattr(card, "schedule_path", ""))

        # metrics table
        table_html = ""
        if jsd_df is not None:
            try:
                df = jsd_df.copy()
                if "jensen_shannon_divergence" in df.columns:
                    df = df.sort_values("jensen_shannon_divergence", ascending=False)
                if top_k:
                    df = df.head(top_k)
                # Rendered by hand so no external styles are pulled in
                rows = []
                for _, r in df.iterrows():
                    col = escape(str(r.get("column", "")))
                    val = r.get("jensen_shannon_divergence", None)
                    rows.append(f"<tr><td>{col}</td><td class='right'>{_fmt(val)}</td></tr>")
                table_html = f"""
                <table id="tbl-{anchor}" data-sort-dir="asc">
                  <thead>
                    <tr>
                      <th onclick="sortTable('tbl-{anchor}', 0, false)">Колонка</th>
                      <th class='right' onclick="sortTable('tbl-{anchor}', 1, true)">JSD</th>
                    </tr>
                  </thead>
                  <tbody>
                    {''.join(rows)}
                  </tbody>
                </table>
                """
            except Exception:
                table_html = "<div class='meta'>Нет доступных метрик.</div>"

        title_html = escape(str(name))
        target_html = escape(str(getattr(card, "target", "")))
        badges_html = ''.join(badges)
        cards_html.append(f"""
        <section class="card" id="{anchor}">
          <h2>{title_html}</h2>
          <div class="meta">target: {target_html}</div>
          <div class="badges">{badges_html}</div>
          {img_html}
          <details open>
            <summary>Метрики распределений (Jensen–Shannon) — {top_label}</summary>
            {table_html}
          </details>
        </section>
        """)

    generated_at = escape(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    document = f"""<!doctype html>
    <html lang="ru">
      <meta charset="utf-8">
      <meta name="viewport" content="width=device-width, initial-scale=1">
      <title>Cards report</title>
      {style}
      {script}
      <body>
        <h1>Сводный отчёт по моделям</h1>
        <div class="subtitle">Сгенерировано: {generated_at}</div>
        <h3>Сравнение</h3>
        {cmp_table}
        <h3>Карточки</h3>
        <div class="grid">
          {''.join(cards_html)}
        </div>
        <div class="footer"><span class="hint">Подсказка: кликайте по заголовкам таблиц, чтобы сортировать.</span></div>
      </body>
    </html>"""

    output_path = str(output_path)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(document)
    return output_path
