In [None]:
import os
import shutil
import urllib.request

import pandas as pd
from thirdai import bolt

# Directory URL of the UCI "Adult" (census income) dataset; the data file name
# is appended to it when downloading.
CENSUS_INCOME_BASE_DOWNLOAD_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/"

# Local path (relative to the notebook's working directory) of the training CSV.
TABULAR_TRAIN_FILE = "./census_income_train.csv"

# CSV header, in file column order. The final column is the prediction target.
column_names = [
    "age",
    "workclass",
    "fnlwgt",
    "education",
    "education-num",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "capital-gain",
    "capital-loss",
    "hours-per-week",
    "native-country",
    "label",
]

def download_census_income_dataset(base_url=None, dest=None):
    """Download the UCI Adult ``adult.data`` file to ``dest`` unless it already exists.

    Args:
        base_url: Directory URL that contains ``adult.data`` (must end with "/").
            Defaults to ``CENSUS_INCOME_BASE_DOWNLOAD_URL``.
        dest: Local file path to write. Defaults to ``TABULAR_TRAIN_FILE``.

    Raises:
        urllib.error.URLError: if the download fails. ``HTTPError`` is a
            subclass, so 4xx/5xx responses raise instead of being saved.
    """
    base_url = CENSUS_INCOME_BASE_DOWNLOAD_URL if base_url is None else base_url
    dest = TABULAR_TRAIN_FILE if dest is None else dest
    if os.path.exists(dest):
        return

    # Write to a sibling temp file and rename only after a complete, successful
    # transfer. Otherwise a failed or interrupted download would leave a
    # truncated or error-page file at ``dest``, which the existence check above
    # would then accept on every later run.
    partial = f"{dest}.part"
    try:
        with urllib.request.urlopen(base_url + "adult.data") as response, open(
            partial, "wb"
        ) as out:
            shutil.copyfileobj(response, out)
        os.replace(partial, dest)
    finally:
        if os.path.exists(partial):
            os.remove(partial)

def reformat_train_csv(path=None, columns=None):
    """Rewrite a raw data file in place as a headered, comma-separated CSV.

    Every ", " separator in a data row is collapsed to ",", and a header line
    built from ``columns`` is written first.

    The rewrite is idempotent. If the file already starts with exactly that
    header line, the line is dropped before the header is re-added. Re-running
    the cell therefore never stacks a second header, which would otherwise
    surface "label" as an extra target value. Blank lines (e.g. a trailing
    empty line in the download) are dropped too.

    Args:
        path: File to rewrite in place. Defaults to ``TABULAR_TRAIN_FILE``.
        columns: Header column names. Defaults to ``column_names``.
    """
    path = TABULAR_TRAIN_FILE if path is None else path
    columns = column_names if columns is None else columns
    header = ",".join(columns)

    with open(path, "r") as file:
        lines = [line.rstrip("\n") for line in file]

    if lines and lines[0] == header:
        lines = lines[1:]

    rows = [line.replace(", ", ",") for line in lines if line.strip()]

    with open(path, "w") as file:
        file.write(header + "\n")
        file.writelines(row + "\n" for row in rows)


def _prepare_training_data():
    """Fetch the raw training file (skipped if present), then rewrite it as a headered CSV."""
    download_census_income_dataset()
    reformat_train_csv()


_prepare_training_data()

: 

In [None]:
# Column schema for the UDT, with one entry per CSV column in `column_names`.
# Numerical ranges and categorical cardinalities are hard-coded.
# NOTE(review): presumably derived from adult.data (categorical counts appear
# to include the "?" missing-value marker); recheck if the data source changes.
census_data_types = {
    "age": bolt.types.numerical(range=(17, 90)),
    "workclass": bolt.types.categorical(n_unique_classes=9),
    "fnlwgt": bolt.types.numerical(range=(12285, 1484705)),
    "education": bolt.types.categorical(n_unique_classes=16),
    "education-num": bolt.types.categorical(n_unique_classes=16),
    "marital-status": bolt.types.categorical(n_unique_classes=7),
    "occupation": bolt.types.categorical(n_unique_classes=15),
    "relationship": bolt.types.categorical(n_unique_classes=6),
    "race": bolt.types.categorical(n_unique_classes=5),
    "sex": bolt.types.categorical(n_unique_classes=2),
    "capital-gain": bolt.types.numerical(range=(0, 99999)),
    "capital-loss": bolt.types.numerical(range=(0, 4356)),
    "hours-per-week": bolt.types.numerical(range=(1, 99)),
    "native-country": bolt.types.categorical(n_unique_classes=42),
    "label": bolt.types.categorical(n_unique_classes=2),
}

# The model predicts the binary "label" column from all the other columns.
tabular_model = bolt.UniversalDeepTransformer(
    data_types=census_data_types,
    target="label",
)

# 5 epochs at learning rate 0.01, tracking the categorical_accuracy metric.
train_config = bolt.TrainConfig(epochs=5, learning_rate=0.01).with_metrics(
    ["categorical_accuracy"]
)

# Train on the headered CSV written by the data-preparation cell.
tabular_model.train(TABULAR_TRAIN_FILE, train_config)


input_1 (Input): dim=100000
input_1 -> fc_1 (FullyConnected): dim=512, sparsity=1, act_func=ReLU
fc_1 -> fc_2 (FullyConnected): dim=2, sparsity=1, act_func=Softmax

Loading vectors from './census_income_train.csv'


ValueError: [ThreadSafeVocabulary] Expected 2 unique strings but found more.

: 

In [None]:
# A sample record, shown below as a raw CSV row, converted to a
# column -> string mapping for `explain`. The target column ("label") is
# omitted. Values must match the reformatted training CSV exactly, which has
# no space after the commas. A value such as " Male" is a different string
# from "Male" and would presumably be treated as an unseen category.
#25,Private,226802,11th,7,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,<=50K

dict_for_explainability = {
    "age": "25",
    "workclass": "Private",
    "fnlwgt": "226802",
    "education": "11th",
    "education-num": "7",
    "marital-status": "Never-married",
    "occupation": "Machine-op-inspct",
    "relationship": "Own-child",
    "race": "Black",
    "sex": "Male",
    "capital-gain": "0",
    "capital-loss": "0",
    "hours-per-week": "40",
    "native-country": "United-States"
}


# Ask the model which input columns drive its score for the "<=50K" class,
# then print each returned explanation on its own line.
explanations = tabular_model.explain(
    input_sample=dict_for_explainability,
    target_class="<=50K",
)
for column_explanation in explanations:
    print(column_explanation)

column_name: "relationship" | keyword: "Own-child" | percentage_significance: 19.2944
column_name: "age" | keyword: "25" | percentage_significance: 18.0571
column_name: "marital-status" | keyword: "Never-married" | percentage_significance: 16.3937
column_name: "capital-gain" | keyword: "0" | percentage_significance: 11.1975
column_name: "occupation" | keyword: "Machine-op-inspct" | percentage_significance: 11.0844
column_name: "education" | keyword: "11th" | percentage_significance: 6.69032
column_name: "race" | keyword: "Black" | percentage_significance: 6.5515
column_name: "education-num" | keyword: "7" | percentage_significance: 6.49143
column_name: "capital-loss" | keyword: "0" | percentage_significance: 1.771
column_name: "hours-per-week" | keyword: "40" | percentage_significance: 1.51048
column_name: "fnlwgt" | keyword: "226802" | percentage_significance: -0.476827
column_name: "native-country" | keyword: "United-States" | percentage_significance: -0.353798
column_name: "workclas

: 

In [None]:
# Delete the downloaded training CSV. A missing file is tolerated, so this
# cleanup cell can be re-run (or run after a partial failure) without raising.
try:
    os.remove(TABULAR_TRAIN_FILE)
except FileNotFoundError:
    pass


NameError: name 'TABULAR_TRAIN_FILE' is not defined

: 