# Лабораторная работа №4 — Байесовские сети
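Датасет: **Mushroom** (`mushrooms.csv`: целевой столбец `class` и 22 категориальных признака).

Цель работы: восстановить структуру байесовской сети PC-алгоритмом (тест $\chi^2$), обучить параметры сети с априорным распределением K2 и выполнить вероятностный вывод для `class` при заданных наблюдениях.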

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from sklearn.preprocessing import LabelEncoder

from pgmpy.estimators import PC
from pgmpy.models import BayesianNetwork
from pgmpy.estimators import BayesianEstimator
from pgmpy.inference import VariableElimination

# 1. Data
# Keep the raw (string) frame and one fitted LabelEncoder per column, so the
# integer codes used as evidence below can be mapped back to the original
# category letters (e.g. encoders["odor"].classes_[1]). A single shared encoder,
# as in `df.apply(LabelEncoder().fit_transform)`, is refitted on every column and
# ends up holding only the last column's mapping. The codes themselves are the
# same either way: LabelEncoder numbers each column's sorted unique values.
df_raw = pd.read_csv("mushrooms.csv")
encoders = {col: LabelEncoder().fit(df_raw[col]) for col in df_raw.columns}
df = pd.DataFrame(
    {col: enc.transform(df_raw[col]) for col, enc in encoders.items()},
    index=df_raw.index,
)
print("Размер датасета:", df.shape)

# 2. Structure search: PC algorithm with a chi-square conditional-independence
# test at significance level 0.01, returning a fully oriented DAG.
pc = PC(data=df)
pc_settings = {
    "ci_test": "chi_square",
    "return_type": "dag",
    "significance_level": 0.01,
    "show_progress": False,
}
dag = pc.estimate(**pc_settings)

found_edges = list(dag.edges())
print("Найдённые связи:")
print(found_edges)

# 3. Parameter learning
# Build the network over the full variable set, not just from the learned edge
# list. `BayesianNetwork(dag.edges())` only creates nodes that appear in some
# edge, so any variable PC left without edges would be silently dropped from the
# model and could not be queried or used as evidence. NOTE(review): a
# single-valued column (presumably `veil-type` in this dataset) is the likely
# candidate — verify via the printed edge list.
model = BayesianNetwork()
model.add_nodes_from(list(df.columns))
model.add_edges_from(dag.edges())
# CPDs are estimated with a Bayesian estimator and a K2 prior.
model.fit(df, estimator=BayesianEstimator, prior_type="K2")
print("Параметры сети обучены.")

# 4. Graph visualisation
# Draw every model node, isolated ones included, with a fixed layout seed.
# Without an explicit `pos`, nx.draw uses a randomly initialised spring layout,
# so the picture would change on every Restart & Run All.
LAYOUT_SEED = 42

G = nx.DiGraph()
G.add_nodes_from(model.nodes())
G.add_edges_from(model.edges())
pos = nx.spring_layout(G, seed=LAYOUT_SEED)

fig, ax = plt.subplots(figsize=(12, 10))
nx.draw(G, pos=pos, ax=ax, with_labels=True, node_size=1800, font_size=9)
ax.set_title("Структура Байесовской сети (PC-алгоритм)")
plt.show()

# 5. Adjacency-matrix heatmap over all dataset columns.
# Cell (row, col) is 1 when the model has a directed edge row -> col and 0
# otherwise. The matrix is binary and does not encode edge strength.
nodes = list(df.columns)
adj = pd.DataFrame(0, index=nodes, columns=nodes)
for src, dst in model.edges():
    adj.at[src, dst] = 1

fig, ax = plt.subplots(figsize=(14, 10))
sns.heatmap(adj, cmap="viridis", annot=True, fmt="d", ax=ax)
ax.set_title("Heatmap силы связей (Adjacency Matrix)")
plt.show()

# 6. Inference: three example queries of P(class | evidence) via variable elimination.
# Evidence values are LabelEncoder codes, i.e. indices into each column's sorted
# unique values, not the raw category letters. NOTE(review): decode them before
# interpreting the result (e.g. odor=1 is presumably "c" — confirm against the
# column's encoder classes).
infer = VariableElimination(model)

scenarios = [
    {"odor": 1},
    {"cap-color": 3, "gill-color": 2},
    {"stalk-shape": 1},
]

results = []
for number, evidence in enumerate(scenarios, start=1):
    factor = infer.query(["class"], evidence=evidence, show_progress=False)
    results.append(factor)
    # Build the header from the evidence dict itself so the printed label can
    # never drift from the evidence that was actually queried.
    label = ", ".join(f"{var}={val}" for var, val in evidence.items())
    print(f"\nInference {number} ({label}):")
    print(factor)

# Keep the original per-query names available for any later cells.
q1, q2, q3 = results

# 7. Baseline: share of the most frequent class, i.e. the accuracy of always
# predicting the majority class. It is a reference point for the posteriors above.
class_shares = df["class"].value_counts(normalize=True)
baseline = class_shares.max()
print("\nBaseline (доля самого частого класса):", baseline)


INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: 
 {'class': 'N', 'cap-shape': 'N', 'cap-surface': 'N', 'cap-color': 'N', 'bruises': 'N', 'odor': 'N', 'gill-attachment': 'N', 'gill-spacing': 'N', 'gill-size': 'N', 'gill-color': 'N', 'stalk-shape': 'N', 'stalk-root': 'N', 'stalk-surface-above-ring': 'N', 'stalk-surface-below-ring': 'N', 'stalk-color-above-ring': 'N', 'stalk-color-below-ring': 'N', 'veil-type': 'N', 'veil-color': 'N', 'ring-number': 'N', 'ring-type': 'N', 'spore-print-color': 'N', 'population': 'N', 'habitat': 'N'}


Размер датасета: (8124, 23)
