In [1]:
import pandas as pd

# Column names for the UCI "adult" files; the raw files carry no header row,
# so the names are supplied explicitly when reading.
CSV_HEADERS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education_num',
    'marital_status', 'occupation', 'relationship', 'race', 'gender',
    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
    'income_bracket',
]

# Training split.
train_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data'
df_train = pd.read_csv(train_url, names=CSV_HEADERS, header=None)

# Held-out split; the first line of this file is skipped on read.
test_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test'
df_test = pd.read_csv(test_url, names=CSV_HEADERS, header=None, skiprows=1)

def load_data(df):
    """Split a raw census frame into numeric features, categorical features and labels.

    The input frame is not modified.

    Args:
        df: DataFrame containing at least the columns 'fnlwgt',
            'education_num', 'capital_gain', 'capital_loss',
            'hours_per_week' and 'income_bracket'.

    Returns:
        A tuple ``(X_num, X_cat, y)``:
          * X_num: float32 DataFrame with the three numeric columns.
          * X_cat: DataFrame of whitespace-stripped strings holding every
            remaining column except 'income_bracket'.
          * y: Series of stripped 'income_bracket' strings with every
            literal '.' removed.
    """
    # drop() already returns a new frame, so no explicit copy is needed.
    df = df.drop(columns=['fnlwgt', 'education_num']).reset_index(drop=True)

    numeric_cols = ['capital_gain', 'capital_loss', 'hours_per_week']
    X_num = df[numeric_cols].astype('float32')

    # NOTE(review): 'age' is not listed in numeric_cols, so it is stringified
    # and returned as a categorical column -- confirm this is intended.
    categoric_cols = [c for c in df.columns if c not in numeric_cols]
    X_cat = df[categoric_cols].astype(str).apply(lambda s: s.str.strip())

    # regex=False: '.' must be removed literally. Relying on the default is
    # version dependent, and as a regex '.' would match every character.
    y = X_cat['income_bracket'].str.replace('.', '', regex=False)
    X_cat = X_cat.drop(columns=['income_bracket'])

    return X_num, X_cat, y

# Apply the same feature/label split to the training and the held-out frames.
X_num_train, X_cat_train, y_train = load_data(df_train)
X_num_test, X_cat_test, y_test = load_data(df_test)

In [2]:
from sklearn import preprocessing

# Binarize the label strings; classes are learned from the training labels only.
lb = preprocessing.LabelBinarizer()
lb.fit(y=y_train)
y_train = lb.transform(y=y_train).squeeze()
y_test = lb.transform(y=y_test).squeeze()

# Standardize numeric features with statistics from the training split.
# fit_transform() already fits, so a separate fit() call is not needed.
scaler = preprocessing.StandardScaler()
X_num_train = scaler.fit_transform(X=X_num_train)
X_num_test = scaler.transform(X=X_num_test)

# Integer-encode categorical columns. Categories not seen during fit are
# emitted as -1 by the encoder.
encoder = preprocessing.OrdinalEncoder(unknown_value=-1,
    handle_unknown='use_encoded_value')
encoder.fit(X=X_cat_train)
X_cat_train = encoder.transform(X=X_cat_train)
X_cat_test = encoder.transform(X=X_cat_test)

num_features = X_num_train.shape[1]
cat_cardinalities = (X_cat_train.max(axis=0) + 1).astype(int).tolist()

# -1 is not a valid embedding index. The model allocates c + 1 rows per
# column of cardinality c, so index c is free: send unseen categories there.
for col, n_known in enumerate(cat_cardinalities):
    X_cat_test[X_cat_test[:, col] < 0, col] = n_known

In [3]:
import torch
# Fix the global torch RNG so runs of this cell are repeatable.
torch.manual_seed(42)

class CensusDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding (numeric, categorical, label) tensors per row.

    Inputs are copied into tensors once at construction time:
    numeric features and labels as float32, categorical codes as int64.
    """

    def __init__(self, X_num, X_cat, y):
        """Convert the three array-likes into tensors held on the instance."""
        self.X_num = torch.tensor(X_num, dtype=torch.float32)
        self.X_cat = torch.tensor(X_cat, dtype=torch.long)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        """Number of samples, taken from the label tensor."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the (numeric, categorical, label) entries at ``idx``."""
        numeric = self.X_num[idx]
        categorical = self.X_cat[idx]
        label = self.y[idx]
        return numeric, categorical, label

# Training data: batches of 256, reshuffled on every pass.
ds_train = CensusDataset(X_num_train, X_cat_train, y_train)
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=256, shuffle=True)

# The held-out file is divided 50/50 into validation and test subsets,
# using a dedicated generator seeded with 42 so the split is repeatable.
ds_temp = CensusDataset(X_num_test, X_cat_test, y_test)
split_rng = torch.Generator().manual_seed(42)
ds_val, ds_test = torch.utils.data.random_split(
    ds_temp, [0.5, 0.5], generator=split_rng)
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=256)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=256)

In [44]:
import torch
# Re-seed the global torch RNG before the model is built in this cell.
torch.manual_seed(42)

class CensusClassifier(torch.nn.Module):
    """Embeds categorical columns and concatenates them with numeric features.

    One embedding table is created per categorical column. A table for a
    column of cardinality ``c`` has ``c + 1`` rows and a width of
    ``min(8, (c + 1) // 2)``, floored at 1.

    TODO: no classification head is defined yet; ``forward`` currently
    returns the concatenated feature tensor.

    Args:
        cat_cardinalities: iterable of ints, one cardinality per categorical
            column, in the column order of ``x_cat``.

    Attributes:
        emb_layers: ModuleList holding one Embedding per categorical column.
        emb_out_dim: total width of all embeddings once concatenated.
    """

    def __init__(self, cat_cardinalities):
        super().__init__()
        self.emb_layers = torch.nn.ModuleList(modules=[
            torch.nn.Embedding(
                num_embeddings=c + 1,
                # max(1, ...) keeps the width non-zero when c == 0.
                embedding_dim=max(1, min(8, (c + 1) // 2)),
            )
            for c in cat_cardinalities
        ])
        # Kept on the instance so a later head can size its input layer.
        self.emb_out_dim = sum(e.embedding_dim for e in self.emb_layers)

    def forward(self, x_num, x_cat):
        """Return numeric features concatenated with all column embeddings.

        Args:
            x_num: 2-D float tensor of numeric features, one row per sample.
            x_cat: 2-D integer tensor; column ``i`` is looked up in
                ``emb_layers[i]``.

        Returns:
            Tensor with ``x_num.shape[1] + emb_out_dim`` columns per row.
        """
        x_emb = [emb(x_cat[:, i]) for i, emb in enumerate(self.emb_layers)]
        x_emb = torch.cat(tensors=x_emb, dim=1)
        # Returned instead of printed so callers (and the notebook's rich
        # display) can use the result.
        return torch.cat(tensors=[x_num, x_emb], dim=1)

# Smoke test: build the model and push the full training set through it once.
model = CensusClassifier(cat_cardinalities)
x_num = torch.tensor(X_num_train, dtype=torch.float32)
x_cat = torch.tensor(X_cat_train, dtype=torch.long)
model(x_num=x_num, x_cat=x_cat)

tensor([[ 0.1485, -0.2167, -0.0354,  ...,  1.1485,  0.7259, -0.2937],
        [-0.1459, -0.2167, -2.2222,  ...,  1.1485,  0.7259, -0.2937],
        [-0.1459, -0.2167, -0.0354,  ...,  1.1485,  0.7259, -0.2937],
        ...,
        [-0.1459, -0.2167, -0.0354,  ...,  1.1485,  0.7259, -0.2937],
        [-0.1459, -0.2167, -1.6552,  ...,  1.1485,  0.7259, -0.2937],
        [ 1.8884, -0.2167, -0.0354,  ...,  1.1485,  0.7259, -0.2937]],
       grad_fn=<CatBackward0>)
