In [1]:
import ethnicolr
import pandas as pd
import polars as pl
import surgeo
import torch
import torch.nn as nn
import tqdm
from typing import Callable

from utils.constants import RACES, VALID_NAME_CHARS_DICT, VALID_NAME_CHARS_LEN, DEVICE
from models import FirstLastBiLSTM, FirstLastZctaBiLSTM
from utils.paths import FINAL_PATH
from utils.utils import prepare_name, encode_scalar

HIDDEN_SIZE = 128
RACES = list(RACES)


In [None]:
def prepare_surgeo_preds(df: pd.DataFrame, tag: str) -> pl.DataFrame:
    df["index"] = df.index
    df = df.rename(columns={"api": "asian"})
    df = df[["index"] + RACES]
    df["race"] = df[RACES].idxmax(axis=1)
    df.columns = [f"{tag}_{col}" for col in df.columns]
    df = df.rename(columns={f"{tag}_index": "index"})

    return pl.from_pandas(df)


def calculate_accuracies(df: pl.DataFrame) -> pl.DataFrame:
    for tag in ("model", "bifsg", "bisg", "eth"):
        df = df.with_columns(
            (pl.col("actual_race") == pl.col(f"{tag}_race")).alias(f"{tag}_correct")
        ).with_columns(
            (
                pl.col(f"{tag}_correct").sum().over("actual_race")
                / pl.col("num_in_race")
            ).alias(f"{tag}_correct_by_race")
        )

    return df


def get_output(model: nn.Module, tup: tuple) -> torch.Tensor:
    dataloader = DataLoader(
        dataset_cls(data),
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )


def flz_get_output(
    model: FirstLastZctaLSTM, tup: tuple[str, str, float]
) -> torch.Tensor:
    first_name, last_name, pct = tup
    name = prepare_name(
        first_name, last_name, VALID_NAME_CHARS_DICT, VALID_NAME_CHARS_LEN, DEVICE
    )
    pct = encode_scalar(pct, DEVICE)
    hidden = model.init_hidden()
    for i in range(name.size()[0]):
        output, hidden = model(name[i], pct, hidden)

    return output


def fl_get_output(model: FirstLastLSTM, tup: tuple[str, str]) -> torch.Tensor:
    first_name, last_name = tup
    name = prepare_name(
        first_name, last_name, VALID_NAME_CHARS_DICT, VALID_NAME_CHARS_LEN, DEVICE
    )
    hidden = model.init_hidden()
    for i in range(name.size()[0]):
        output, hidden = model(name[i], hidden)

    return output


def get_model_preds(
    model: nn.Module, data: pl.DataFrame, cols: list[str], get_output_func: Callable
) -> pl.DataFrame:
    with torch.no_grad():
        model.eval()

        test = data.to_dict()
        for race in RACES:
            test[race] = []
        test["model_race"] = []

        race_mapper = {i: r for i, r in enumerate(RACES)}

        data = list(
            zip(
                [test[col] for col in cols],
                strict=True,
            )
        )

        for d in tqdm.tqdm(data):
            output = get_output_func(d)
            percentages = nn.functional.softmax(output, dim=1)
            predicted = percentages.argmax().item()
            for idx, p in enumerate(percentages.tolist()[0]):
                test[race_mapper[idx]].append(p)
            test["model_race"].append(race_mapper[predicted])

    preds = pl.DataFrame(test).rename(
        {
            "race_ethnicity": "actual_race",
            "asian": "model_asian",
            "black": "model_black",
            "hispanic": "model_hispanic",
            "white": "model_white",
        }
    )

    return preds

In [None]:
test = pl.read_parquet(FINAL_PATH / "test_sample.parquet").sample(1_000_000)
test.write_parquet(FINAL_PATH / "to_test.parquet")

In [None]:
model = FirstLastZctaLSTM(VALID_NAME_CHARS_LEN, HIDDEN_SIZE, len(RACES)).to(DEVICE)
model.load_state_dict(
    torch.load(
        FINAL_PATH / "first_last_zcta/model2_10000.pth",
        map_location=torch.device("cpu"),
    )
)
flz = get_model_preds(
    model,
    data=test,
    cols=["first_name", "last_name", "pct_race_zcta"],
    get_output_func=flz_get_output,
)
flz.write_parquet(FINAL_PATH / "flz_preds.parquet")

model = FirstLastLSTM(VALID_NAME_CHARS_LEN, HIDDEN_SIZE, len(RACES)).to(DEVICE)
model.load_state_dict(
    torch.load(
        FINAL_PATH / "first_last/model2_10000.pth", map_location=torch.device("cpu")
    )
)
fl = get_model_preds(
    model, data=test, cols=["first_name", "last_name"], get_output_func=fl_get_output
)
fl.write_parquet(FINAL_PATH / "fl_pres.parquet")

In [None]:
# View accuracy
model_preds = (
    pl.scan_parquet(FINAL_PATH / "to_test.parquet")
    .rename(
        {
            "race_ethnicity": "actual_race",
            "asian": "model_asian",
            "black": "model_black",
            "hispanic": "model_hispanic",
            "white": "model_white",
        }
    )
    .with_columns(
        num_in_race=pl.col("index").count().over("actual_race"),
    )
    .collect()
)

# reset the index
model_preds = model_preds.with_columns(
    pl.Series(name="index", values=range(model_preds.shape[0]))
)

# save a pandas df since surgeo and ethnicolr use pandas
pd_df = model_preds.to_pandas()  # surgeo requires pandas

# Surgeo

bifsg = surgeo.BIFSGModel()
bisg = surgeo.SurgeoModel()

bifsg_preds = bifsg.get_probabilities(
    pd_df["first_name"], pd_df["last_name"], pd_df["zcta"]
)
model_preds = model_preds.join(prepare_surgeo_preds(bifsg_preds, "bifsg"), on="index")

bisg_preds = bisg.get_probabilities(pd_df["last_name"], pd_df["zcta"])
model_preds = model_preds.join(prepare_surgeo_preds(bisg_preds, "bisg"), on="index")

# ethnicolr

eth_pred = ethnicolr.pred_census_ln(
    pd_df[["last_name"]], lname_col="last_name", year=2010
)
eth_pred["last_name"] = eth_pred["last_name"].str.lower()
eth_pred = eth_pred.rename(columns={"api": "asian"})
eth_pred["race"] = eth_pred[RACES].idxmax(axis=1)
eth_pred.columns = [f"eth_{col}" for col in eth_pred.columns]
eth_pred = eth_pred.rename(columns={"eth_last_name": "last_name"})
model_preds = model_preds.join(pl.from_pandas(eth_pred), on="last_name", how="left")

# accuracy

model_preds = calculate_accuracies(model_preds)

# %%

print(
    model_preds.select(
        "actual_race",
        "model_correct_by_race",
        "bifsg_correct_by_race",
        "bisg_correct_by_race",
        "eth_correct_by_race",
        "num_in_race",
    )
    .unique("actual_race")
    .sort("actual_race")
)

# %%

"""
┌─────────────┬─────────────────┬─────────────────┬─────────────────┬────────────────┬─────────────┐
│ actual_race ┆ model_correct_b ┆ bifsg_correct_b ┆ bisg_correct_by ┆ eth_correct_by ┆ num_in_race │
│ ---         ┆ y_race          ┆ y_race          ┆ _race           ┆ _race          ┆ ---         │
│ str         ┆ ---             ┆ ---             ┆ ---             ┆ ---            ┆ u32         │
│             ┆ f64             ┆ f64             ┆ f64             ┆ f64            ┆             │
╞═════════════╪═════════════════╪═════════════════╪═════════════════╪════════════════╪═════════════╡
│ asian       ┆ 0.641205        ┆ 0.4275          ┆ 0.731739        ┆ 0.704155       ┆ 179523      │
│ hispanic    ┆ 0.808521        ┆ 0.630668        ┆ 0.785496        ┆ 0.89109        ┆ 1420826     │
│ black       ┆ 0.680265        ┆ 0.381849        ┆ 0.632189        ┆ 0.07932        ┆ 203947      │
│ white       ┆ 0.672028        ┆ 0.818344        ┆ 0.839589        ┆ 0.970754       ┆ 1195704     │
└─────────────┴─────────────────┴─────────────────┴─────────────────┴────────────────┴─────────────┘
"""
