In [1]:
# ! pip install bamt

In [2]:
from dataclasses import dataclass
from enum import Enum, auto
from typing import Self
from uuid import uuid4
from bamt.preprocessors import Preprocessor
import pandas as pd
from sklearn import preprocessing as pp
from bamt.networks import HybridBN

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy import stats


class Sex(Enum):
    male = auto()
    female = auto()


class MaritalStatus(Enum):
    single = auto()
    maried = auto()


class Education(Enum):
    secondary = auto()  # среднее
    secondary_proffesional = auto()
    higher = auto()


@dataclass
class Agent:
    id = uuid4()
    age: int
    sex: Sex
    salary: int
    education: Education | str
    mortgage_dept: int | None = None
    marital_status: MaritalStatus | str | None = None
    children: list[Self] | None = None
    parents: list[Self] | None = None

In [3]:
import pandas as pd


df = pd.read_csv("salary.csv")

In [4]:
df.columns

Index(['age', 'workclass', 'fnlwgt', 'education', 'education-num',
       'marital-status', 'occupation', 'relationship', 'race', 'sex',
       'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
       'salary'],
      dtype='object')

In [5]:
def parse_sex(v):
    if v==" Male":
        return Sex.male
    if v==" Female":
        return Sex.female
    return None
population = []
for index, row in df.iterrows():
    agent = Agent(age=row['age'],
                  sex=parse_sex(row["sex"]),
                  education=row["education"],
                  salary=1
                  
                  )
    population.append(agent)

In [6]:
population[23]

Agent(age=43, sex=<Sex.male: 1>, salary=1, education=' 11th', mortgage_dept=None, marital_status=None, children=None, parents=None)

In [7]:
rlms = pd.read_excel("Data_RLMS.xlsx")
rlms.columns

Index(['idind', 'psu', 'status', 'age', 'male', 'industry', 'lnwage', 'public',
       'internet', 'children', 'urban', 'educ', 'id1', 'id2', 'id3', 'id4',
       'id5', 'id6', 'id7', 'id8', 'id9', 'id10', 'id11', 'id12', 'id13',
       'id14', 'id15', 'id16', 'id17', 'id18', 'id19', 'id20', 'id21', 'id22',
       'id23', 'id24', 'id25', 'id26', 'id27', 'id28', 'id29', 'id30', 'id31',
       'id32', 'id33', 'id34', 'id35', 'id36', 'id37', 'id38', 'id39', 'id40',
       'id41', 'id42', 'id43', 'id44', 'id45', 'id46', 'id47', 'id48', 'id49',
       'id50', 'id51', 'id52', 'id53', 'id54', 'id55', 'id56', 'id57', 'id58',
       'id59', 'id60', 'id61', 'id62', 'id63', 'id64', 'id65', 'id66', 'id67',
       'id68', 'id69', 'id70', 'id71', 'id72', 'id73', 'id74', 'id75'],
      dtype='object')

In [31]:
data = rlms[["age", "male", "industry", "lnwage",
             "public", "children", "urban", "educ"]]

data = rlms[["age",  "lnwage", "children", 
             ]]

data["public"] = rlms.public.astype(str)
data["male"] = rlms.male.astype(str)
data["industry"] = rlms.industry.astype(str)
# data["urban"] = rlms.urban.astype(str)
# data["educ"] = rlms.educ.astype(str)

In [32]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3871 entries, 0 to 3870
Data columns (total 6 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   age       3871 non-null   int64  
 1   lnwage    3871 non-null   float64
 2   children  3871 non-null   int64  
 3   public    3871 non-null   object 
 4   male      3871 non-null   object 
 5   industry  3871 non-null   object 
dtypes: float64(1), int64(2), object(3)
memory usage: 181.6+ KB


In [33]:
# set encoder and discretizer
encoder = pp.LabelEncoder()
discretizer = pp.KBinsDiscretizer(n_bins=5, encode='ordinal', strategy='uniform')

# create preprocessor object with encoder and discretizer
p = Preprocessor([('encoder', encoder), ('discretizer', discretizer)])

# discretize data for structure learning
discretized_data, est = p.apply(data)

# get information about data
info = p.info

In [34]:
info

{'types': {'age': 'disc_num',
  'lnwage': 'cont',
  'children': 'disc_num',
  'public': 'disc',
  'male': 'disc',
  'industry': 'disc'},
 'signs': {'lnwage': 'pos'}}

In [35]:
# initialize network object
bn = HybridBN(use_mixture=True, has_logit=True)

# add nodes to network
bn.add_nodes(info)

# using mutual information as scoring function for structure learning
bn.add_edges(discretized_data, 
            #  scoring_function=('MI',)
             )

# or use evolutionary algorithm to learn structure

bn.add_edges(discretized_data)

bn.fit_parameters(data)

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

In [36]:
bn.validate(info)

True

In [15]:
bn.save("pop.json")

True

In [57]:
sampled_data = bn.sample(10_000, progress_bar=False)
sampled_data = sampled_data[sampled_data.age!='nan']

In [58]:
sampled_data.age.unique()

array(['37', '51', '33', '27', '55', '42', '45', '54', '32', '25', '28',
       '56', '57', '30', '46', '39', '44', '23', '40', '47', '43', '36',
       '41', '29', '50', '52', '26', '20', '58', '53', '19', '31', '35',
       '38', '34', '48', '21', '49', '22', '59', '24', '18', '17'],
      dtype=object)

In [61]:
sampled_data.children = sampled_data.children.astype(int)
sampled_data.public = sampled_data.public.astype(int)
sampled_data.male = sampled_data.male.astype(int)
sampled_data.age = sampled_data.age.astype(int)
sampled_data.industry = sampled_data.industry.astype(str)

In [62]:
sampled_data.info()

<class 'pandas.core.frame.DataFrame'>
Index: 9753 entries, 0 to 9999
Data columns (total 6 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   industry  9753 non-null   object 
 1   public    9753 non-null   int64  
 2   male      9753 non-null   int64  
 3   lnwage    9753 non-null   float64
 4   children  9753 non-null   int64  
 5   age       9753 non-null   int64  
dtypes: float64(1), int64(4), object(1)
memory usage: 533.4+ KB


In [63]:
data.age.info()

<class 'pandas.core.series.Series'>
RangeIndex: 3871 entries, 0 to 3870
Series name: age
Non-Null Count  Dtype
--------------  -----
3871 non-null   int64
dtypes: int64(1)
memory usage: 30.4 KB


In [66]:
fig = go.Figure()
fig.update_layout(
    title="Age")

fig.add_trace(go.Histogram(x=sorted(sampled_data.age), histnorm='probability density', name=f"sampled"))
fig.add_trace(go.Histogram(x=sorted(data.age), histnorm='probability density', name=f"original"))

In [67]:
fig = go.Figure()
fig = go.Figure()
fig.update_layout(
    title="children")
fig.add_trace(go.Histogram(x=sampled_data.children,
              histnorm='probability density', name=f"sampled"))
fig.add_trace(go.Histogram(x=data.children,
              histnorm='probability density', name=f"original"))