In [10]:
#!wget https://raw.githubusercontent.com/SimoneRitt/IACModel_CCM/main/IAC_plotting.py
import IAC_plotting

import pandas as pd
import numpy as np
from sklearn import datasets

# Original Example: Jets and Sharks

In [12]:
#!wget -P data https://raw.githubusercontent.com/SimoneRitt/IACModel_CCM/main/data/jets_sharks.csv

In [13]:
# Load McClelland's Jets and Sharks roster (one row per named individual).
test1 = pd.read_csv(filepath_or_buffer='data/jets_sharks.csv')
test1

Unnamed: 0,Name,Gang,Age,Edu,Mar,Occupation
0,Art,Jets,40's,J.H.,Sing.,Pusher
1,Al,Jets,30's,J.H.,Mar.,Burglar
2,Sam,Jets,20's,COL.,Sing.,Bookie
3,Clyde,Jets,40's,J.H.,Sing.,Bookie
4,Mike,Jets,30's,J.H.,Sing.,Bookie
5,Jim,Jets,20's,J.H.,Div.,Burglar
6,Greg,Jets,20's,H.S.,Mar.,Pusher
7,John,Jets,20's,J.H.,Mar.,Burglar
8,Doug,Jets,30's,H.S.,Sing.,Bookie
9,Lance,Jets,20's,J.H.,Mar.,Burglar


In [14]:
# Draw the IAC network for the Jets and Sharks roster. Unlike the later calls,
# no hidden_state column is passed here.
# NOTE(review): which column becomes the hidden units by default is decided
# inside IAC_plotting.plot (presumably 'Name'); confirm.
IAC_plotting.plot(test1)

# New Example: Iris Dataset

## Creating DataFrame

For effective visualization, each column of the DataFrame should have relatively few unique values, because every unique value becomes a node in that column's pool. We recommend binning continuous columns so they do not produce an excess of nodes.

In [17]:
iris = datasets.load_iris()
# Feature columns come straight from the bunch; the species name for each row
# is looked up by indexing the name array with the integer class codes.
test2 = pd.DataFrame(iris['data'], columns=iris['feature_names'])
test2['target_names'] = iris['target_names'][iris['target']]
test2

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target_names
0,5.1,3.5,1.4,0.2,setosa
1,4.9,3.0,1.4,0.2,setosa
2,4.7,3.2,1.3,0.2,setosa
3,4.6,3.1,1.5,0.2,setosa
4,5.0,3.6,1.4,0.2,setosa
...,...,...,...,...,...
145,6.7,3.0,5.2,2.3,virginica
146,6.3,2.5,5.0,1.9,virginica
147,6.5,3.0,5.2,2.0,virginica
148,6.2,3.4,5.4,2.3,virginica


In [18]:
# Report the span (max - min) of every numeric column, which helps choose a
# bucket width. is_numeric_dtype is used instead of `dtype != object` because
# the latter also lets pandas' dedicated string dtypes through, and subtracting
# strings would raise.
for c in test2.columns:
    if pd.api.types.is_numeric_dtype(test2[c]):
        print(f"Range of {c} = {test2[c].max() - test2[c].min()}")

Range of sepal length (cm) = 3.6000000000000005
Range of sepal width (cm) = 2.4000000000000004
Range of petal length (cm) = 5.9
Range of petal width (cm) = 2.4


Because every range is below 6, we can use unit-width buckets `[x, x+1)`, where `x` is the value rounded down to an integer.

In [20]:
import math 

def create_buckets_1(x):
    """Return the label of the unit-width, left-closed bin containing ``x``.

    The lower bound is ``math.floor(x)``, e.g. 5.1 -> "[5, 6)" and
    -0.5 -> "[-1, 0)".
    """
    floor_value = math.floor(x)
    upper = floor_value + 1
    return f"[{floor_value}, {upper})"

# Replace each numeric measurement with its unit-width bin label. Checking
# is_numeric_dtype (rather than `dtype != object`) keeps string columns such
# as 'target_names' out, even when pandas stores them in a dedicated string
# dtype. It also makes a re-run a no-op, because the bucketed columns are no
# longer numeric.
for c in test2.columns:
    if pd.api.types.is_numeric_dtype(test2[c]):
        test2[c] = test2[c].apply(create_buckets_1)

test2

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target_names
0,"[5, 6)","[3, 4)","[1, 2)","[0, 1)",setosa
1,"[4, 5)","[3, 4)","[1, 2)","[0, 1)",setosa
2,"[4, 5)","[3, 4)","[1, 2)","[0, 1)",setosa
3,"[4, 5)","[3, 4)","[1, 2)","[0, 1)",setosa
4,"[5, 6)","[3, 4)","[1, 2)","[0, 1)",setosa
...,...,...,...,...,...
145,"[6, 7)","[3, 4)","[5, 6)","[2, 3)",virginica
146,"[6, 7)","[2, 3)","[5, 6)","[1, 2)",virginica
147,"[6, 7)","[3, 4)","[5, 6)","[2, 3)",virginica
148,"[6, 7)","[3, 4)","[5, 6)","[2, 3)",virginica


In [21]:
# Draw the IAC network for the bucketed iris data, passing the species column
# as the hidden state.
# NOTE(review): test2 has many rows per species, so one species node may link to
# several bins of the same pool. Confirm that IAC_plotting.plot handles this as intended.
IAC_plotting.plot(test2, hidden_state='target_names')

# New Example: Adult Census Income

A dataset used to predict whether a person's income exceeds \$50K a year, based on census data.

**Citation:** https://www.kaggle.com/datasets/uciml/adult-census-income/data

## Creating DataFrame

For the best results, each hidden unit should map to a single visible unit for each pool. Multiple hidden units can map to the same visible unit, but a single hidden unit should not make more than one connection to another pool.

In McClelland's original Jets and Sharks data, this is equivalent to ensuring that no single individual is both Married and Divorced, or in both his 20s and his 40s.

To achieve this, we must take a subset of our data:

In [32]:
# Load the Adult census-income data and report its dimensions before cleaning.
test3 = pd.read_csv(filepath_or_buffer='data/adult.csv')
n_rows, n_cols = test3.shape
print((n_rows, n_cols))
test3.head()

(32561, 15)


Unnamed: 0,age,workclass,fnlwgt,education,education.num,marital.status,occupation,relationship,race,sex,capital.gain,capital.loss,hours.per.week,native.country,income
0,90,?,77053,HS-grad,9,Widowed,?,Not-in-family,White,Female,0,4356,40,United-States,<=50K
1,82,Private,132870,HS-grad,9,Widowed,Exec-managerial,Not-in-family,White,Female,0,4356,18,United-States,<=50K
2,66,?,186061,Some-college,10,Widowed,?,Unmarried,Black,Female,0,4356,40,United-States,<=50K
3,54,Private,140359,7th-8th,4,Divorced,Machine-op-inspct,Unmarried,White,Female,0,3900,40,United-States,<=50K
4,41,Private,264663,Some-college,10,Separated,Prof-specialty,Own-child,White,Female,0,3900,40,United-States,<=50K


Data Cleaning (removing rows with missing values, which this dataset encodes as `?`)

In [35]:
# Missing entries are stored as the literal string '?'. Convert them to NaN and
# drop every row that contains one. Reassigning (instead of two inplace=True
# mutations) keeps the step a single explicit expression.
test3 = test3.replace({'?': np.nan}).dropna()
test3.shape

(30162, 15)

Selecting Columns to Retain (setting pools)

In [38]:
# Count once and reuse the result. The number of distinct values is the number
# of nodes this pool would contribute.
workclass_counts = test3['workclass'].value_counts()
print('UNIQUE VALUES =', len(workclass_counts))
workclass_counts

UNIQUE VALUES = 7


workclass
Private             22286
Self-emp-not-inc     2499
Local-gov            2067
State-gov            1279
Self-emp-inc         1074
Federal-gov           943
Without-pay            14
Name: count, dtype: int64

In [40]:
# Inspect the span of ages (age is bucketed later).
ages = test3['age']
ages.min(), ages.max()

(np.int64(17), np.int64(90))

In [42]:
# Count once and reuse the result. The number of distinct values is the number
# of nodes this pool would contribute.
education_counts = test3['education'].value_counts()
print('UNIQUE VALUES =', len(education_counts))
education_counts

UNIQUE VALUES = 16


education
HS-grad         9840
Some-college    6678
Bachelors       5044
Masters         1627
Assoc-voc       1307
11th            1048
Assoc-acdm      1008
10th             820
7th-8th          557
Prof-school      542
9th              455
12th             377
Doctorate        375
5th-6th          288
1st-4th          151
Preschool         45
Name: count, dtype: int64

In [44]:
# Show the raw marital-status categories before the Married variants are merged.
marital_status = test3['marital.status']
marital_status.value_counts()

marital.status
Married-civ-spouse       14065
Never-married             9726
Divorced                  4214
Separated                  939
Widowed                    827
Married-spouse-absent      370
Married-AF-spouse           21
Name: count, dtype: int64

In [46]:
# Collapse every "Married-*" variant into a single 'Married' node so that the
# marital-status pool stays small.
married_variants = ['Married-civ-spouse', 'Married-spouse-absent', 'Married-AF-spouse']
test3['marital.status'] = test3['marital.status'].replace(
    dict.fromkeys(married_variants, 'Married'))

# Count once and reuse the result for both the summary line and the display.
marital_counts = test3['marital.status'].value_counts()
print('UNIQUE VALUES =', len(marital_counts))
marital_counts

UNIQUE VALUES = 5


marital.status
Married          14456
Never-married     9726
Divorced          4214
Separated          939
Widowed            827
Name: count, dtype: int64

In [48]:
# Count once and reuse the result. The number of distinct values is the number
# of nodes this pool would contribute.
occupation_counts = test3['occupation'].value_counts()
print('UNIQUE VALUES =', len(occupation_counts))
occupation_counts

UNIQUE VALUES = 14


occupation
Prof-specialty       4038
Craft-repair         4030
Exec-managerial      3992
Adm-clerical         3721
Sales                3584
Other-service        3212
Machine-op-inspct    1966
Transport-moving     1572
Handlers-cleaners    1350
Farming-fishing       989
Tech-support          912
Protective-serv       644
Priv-house-serv       143
Armed-Forces            9
Name: count, dtype: int64

In [50]:
# Count once and reuse the result. The number of distinct values is the number
# of nodes this pool would contribute.
relationship_counts = test3['relationship'].value_counts()
print('UNIQUE VALUES =', len(relationship_counts))
relationship_counts

UNIQUE VALUES = 6


relationship
Husband           12463
Not-in-family      7726
Own-child          4466
Unmarried          3212
Wife               1406
Other-relative      889
Name: count, dtype: int64

In [52]:
# Count once and reuse the result. The number of distinct values is the number
# of nodes this pool would contribute.
race_counts = test3['race'].value_counts()
print('UNIQUE VALUES =', len(race_counts))
race_counts

UNIQUE VALUES = 5


race
White                 25933
Black                  2817
Asian-Pac-Islander      895
Amer-Indian-Eskimo      286
Other                   231
Name: count, dtype: int64

In [54]:
# Count once and reuse the result. The number of distinct values is the number
# of nodes this pool would contribute.
sex_counts = test3['sex'].value_counts()
print('UNIQUE VALUES =', len(sex_counts))
sex_counts

UNIQUE VALUES = 2


sex
Male      20380
Female     9782
Name: count, dtype: int64

In [56]:
# Count once and reuse the result. The number of distinct values is the number
# of nodes this pool would contribute.
income_counts = test3['income'].value_counts()
print('UNIQUE VALUES =', len(income_counts))
income_counts

UNIQUE VALUES = 2


income
<=50K    22654
>50K      7508
Name: count, dtype: int64

Education has the largest number of unique values (16). We therefore sample exactly one row per education level, giving 16 rows that become our 16 hidden units. Every education value is represented exactly once:

In [59]:
# Draw exactly one row per education level (seeded for reproducibility). Each
# sampled row becomes one hidden unit.
by_education = test3.groupby('education')
test3 = by_education.sample(n=1, random_state=0)
test3 = test3.reset_index(drop=True)
test3

Unnamed: 0,age,workclass,fnlwgt,education,education.num,marital.status,occupation,relationship,race,sex,capital.gain,capital.loss,hours.per.week,native.country,income
0,64,Self-emp-not-inc,388625,10th,6,Married,Prof-specialty,Husband,White,Male,0,0,10,United-States,>50K
1,36,Private,292570,11th,7,Never-married,Machine-op-inspct,Unmarried,White,Female,0,0,40,United-States,<=50K
2,20,Private,224238,12th,8,Never-married,Craft-repair,Own-child,White,Male,0,0,40,United-States,<=50K
3,50,Private,193374,1st-4th,2,Married,Craft-repair,Unmarried,White,Male,0,0,40,United-States,<=50K
4,48,Private,315423,5th-6th,3,Married,Transport-moving,Husband,White,Male,0,0,40,United-States,<=50K
5,38,Self-emp-not-inc,282461,7th-8th,4,Married,Sales,Husband,White,Male,0,0,35,United-States,>50K
6,37,Private,758700,9th,5,Married,Handlers-cleaners,Husband,White,Male,3781,0,50,Mexico,<=50K
7,41,Private,171615,Assoc-acdm,12,Married,Tech-support,Husband,White,Male,0,0,45,United-States,>50K
8,22,Private,171419,Assoc-voc,11,Never-married,Exec-managerial,Unmarried,Asian-Pac-Islander,Male,0,0,40,South,<=50K
9,50,Private,134766,Bachelors,13,Married,Exec-managerial,Husband,White,Male,0,1902,50,United-States,>50K


To simplify the visualization, we will remove the columns `fnlwgt`, `education.num`, `relationship`, `capital.gain`, `capital.loss`, `race`, and `native.country` -- leaving 8 pools.

Although race is interesting to study, it was dropped because nearly all of the sampled rows are "White", with only a single non-White value.

The final pools are:
- age
- work class
- education
- marital status
- occupation
- sex
- hours per week
- income

To minimize the number of visualized nodes and improve readability, we bucket the numeric columns `age` and `hours.per.week` into 10-unit ranges.

In [62]:
# Removing columns that will not become pools
test3 = test3.drop(columns=['fnlwgt', 'education.num', 'relationship',
                            'capital.gain', 'capital.loss', 'race', 'native.country'])

# Renaming pools by name rather than by position, so a change in column order
# cannot silently attach a label to the wrong column.
test3 = test3.rename(columns={'workclass': 'work class',
                              'marital.status': 'marital status',
                              'hours.per.week': 'hours per week'})
assert list(test3.columns) == ['age', 'work class', 'education', 'marital status',
                               'occupation', 'sex', 'hours per week', 'income']

# Bucketing age and hours per week
def create_buckets_10(x):
    """Return the label of the 10-wide, left-closed bin containing ``x``.

    The lower bound is ``x`` floor-divided by 10, times 10, e.g. 64 -> "[60, 70)".
    Integer input yields integer bounds.
    """
    lower_bound = 10 * (x // 10)
    upper_bound = lower_bound + 10
    return f"[{lower_bound}, {upper_bound})"

# Replace the raw numbers in both numeric pools with their 10-wide bin labels.
for column in ('age', 'hours per week'):
    test3[column] = test3[column].apply(create_buckets_10)

test3

Unnamed: 0,age,work class,education,marital status,occupation,sex,hours per week,income
0,"[60, 70)",Self-emp-not-inc,10th,Married,Prof-specialty,Male,"[10, 20)",>50K
1,"[30, 40)",Private,11th,Never-married,Machine-op-inspct,Female,"[40, 50)",<=50K
2,"[20, 30)",Private,12th,Never-married,Craft-repair,Male,"[40, 50)",<=50K
3,"[50, 60)",Private,1st-4th,Married,Craft-repair,Male,"[40, 50)",<=50K
4,"[40, 50)",Private,5th-6th,Married,Transport-moving,Male,"[40, 50)",<=50K
5,"[30, 40)",Self-emp-not-inc,7th-8th,Married,Sales,Male,"[30, 40)",>50K
6,"[30, 40)",Private,9th,Married,Handlers-cleaners,Male,"[50, 60)",<=50K
7,"[40, 50)",Private,Assoc-acdm,Married,Tech-support,Male,"[40, 50)",>50K
8,"[20, 30)",Private,Assoc-voc,Never-married,Exec-managerial,Male,"[40, 50)",<=50K
9,"[50, 60)",Private,Bachelors,Married,Exec-managerial,Male,"[50, 60)",>50K


Each unit has a label corresponding to its value in the DataFrame. Longer labels such as "Self-emp-not-inc" can be hard to read and may overlap neighboring labels. The final preparation step therefore shortens labels where needed.

In [65]:
# Shorten long labels so they do not overlap neighboring node labels in the plot.
test3['work class'] = test3['work class'].replace({'Self-emp-not-inc': 'Self-emp'})
test3['education'] = test3['education'].replace({'Assoc-acdm': 'Assoc.',
                                                 'Assoc-voc': 'Assoc.',
                                                 'Some-college': 'College'})
test3['marital status'] = test3['marital status'].replace({'Never-married': 'Single'})
test3['occupation'] = test3['occupation'].replace({'Prof-specialty': 'Specialty',
                                                   'Machine-op-inspct': 'Machinery',
                                                   'Craft-repair': 'Craft',
                                                   'Transport-moving': 'Transport',
                                                   'Handlers-cleaners': 'Cleaner',
                                                   'Tech-support': 'IT',
                                                   'Exec-managerial': 'Manager',
                                                   'Other-service': 'Other'})

# Merging 'Assoc-acdm' and 'Assoc-voc' leaves two rows labelled 'Assoc.', but
# 'education' is later passed as hidden_state and should hold one row per label.
# Keep the first occurrence, which is the former 'Assoc-acdm' row because the
# groupby sample ordered rows by education. Unlike dropping a randomly sampled
# 'Assoc.' row, this is idempotent: re-running the cell cannot also delete the
# remaining 'Assoc.' row.
test3 = test3.drop_duplicates(subset='education', keep='first')
assert test3['education'].is_unique, "education labels must be unique hidden units"

In [67]:
# Renumber the rows 0..n-1 after the duplicate drop, discarding the old labels.
test3 = test3.set_axis(range(len(test3)), axis=0)
test3

Unnamed: 0,age,work class,education,marital status,occupation,sex,hours per week,income
0,"[60, 70)",Self-emp,10th,Married,Specialty,Male,"[10, 20)",>50K
1,"[30, 40)",Private,11th,Single,Machinery,Female,"[40, 50)",<=50K
2,"[20, 30)",Private,12th,Single,Craft,Male,"[40, 50)",<=50K
3,"[50, 60)",Private,1st-4th,Married,Craft,Male,"[40, 50)",<=50K
4,"[40, 50)",Private,5th-6th,Married,Transport,Male,"[40, 50)",<=50K
5,"[30, 40)",Self-emp,7th-8th,Married,Sales,Male,"[30, 40)",>50K
6,"[30, 40)",Private,9th,Married,Cleaner,Male,"[50, 60)",<=50K
7,"[40, 50)",Private,Assoc.,Married,IT,Male,"[40, 50)",>50K
8,"[50, 60)",Private,Bachelors,Married,Manager,Male,"[50, 60)",>50K
9,"[40, 50)",Private,Doctorate,Divorced,Sales,Female,"[50, 60)",>50K


In [69]:
# Draw the IAC network for the reduced census sample, using the education
# column as the hidden state (one hidden unit per education label).
IAC_plotting.plot(test3, hidden_state='education')