In [4]:
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from aif360.datasets import BinaryLabelDataset
from aif360.metrics import ClassificationMetric
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader, TensorDataset

In [None]:
# 1) Read the biography table: raw text, gender column (label) and
#    profession column (protected attribute)
df = pd.read_csv(r"bias_bio.csv", index_col=0)
texts = df['hard_text'].astype(str).tolist()
y, S = df['gender'].values, df['profession'].values

# 2) BERT tokenizer + encoder, switched to eval mode and moved to the GPU
#    when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tok = AutoTokenizer.from_pretrained('bert-base-uncased')
bert = AutoModel.from_pretrained('bert-base-uncased').eval().to(device)

def embed_texts(texts, batch_size=32, max_len=128):
    """Embed texts with BERT using attention-masked mean pooling.

    Uses the module-level ``tok`` (tokenizer), ``bert`` (encoder) and
    ``device`` objects.

    Args:
        texts: Sized, sliceable sequence of strings (e.g. a list).
        batch_size: Number of texts tokenized and encoded per forward pass.
        max_len: Passed to the tokenizer as ``max_length`` with
            ``truncation=True``.

    Returns:
        A CPU tensor with one row per input text, each row being the mean of
        the encoder's last hidden states over the positions where the
        attention mask is 1. For empty input a ``(0, hidden_size)`` tensor
        is returned instead of failing.
    """
    if len(texts) == 0:
        # torch.cat() rejects an empty list, so return a correctly shaped
        # empty result that downstream code can still call .size(1) on.
        return torch.empty((0, bert.config.hidden_size), dtype=bert.dtype)

    all_emb = []
    for i in range(0, len(texts), batch_size):
        # list(...) lets callers pass tuples or other sliceable sequences.
        batch = list(texts[i:i+batch_size])
        enc   = tok(batch, padding=True, truncation=True, max_length=max_len, return_tensors='pt').to(device)
        with torch.no_grad():
            out  = bert(**enc).last_hidden_state
            # Broadcast the attention mask over the hidden dimension so that
            # padded positions contribute nothing to the sum.
            mask = enc['attention_mask'].unsqueeze(-1)
            summed = (out * mask).sum(dim=1)
            # Guard the divisor: a row whose mask is all zeros would
            # otherwise yield NaNs.
            counts = mask.sum(dim=1).clamp(min=1)
            emb    = (summed / counts).cpu()
        all_emb.append(emb)
    return torch.cat(all_emb, dim=0)

# One mean-pooled BERT vector per biography, stacked row-wise into a CPU
# tensor; X.size(1) later serves as the LSTM input dimension.
X = embed_texts(texts)

# 3) Define the LSTM class (must match your training code)
class OneStepLSTM(nn.Module):
    """LSTM classifier that feeds each input vector as a length-1 sequence.

    The sub-modules are stored as ``lstm`` and ``fc``; these names form the
    state-dict keys, so they must stay in sync with the saved checkpoints.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_classes):
        """Build the recurrent layer(s) followed by the linear output head."""
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, x):
        """Return class logits computed from the last layer's final hidden state."""
        # Insert a sequence axis of length one right after the batch axis.
        sequence = x[:, None]
        _, (hidden, _) = self.lstm(sequence)
        top_layer_state = hidden[-1]
        return self.fc(top_layer_state)

# 4) Restore the baseline ("unfair") classifier from its checkpoint
model_unfair = OneStepLSTM(
    input_dim=X.size(1), hidden_dim=128, num_layers=1, num_classes=2
)
model_unfair = model_unfair.to(device)
chk = torch.load("lstm_gender_classifier.pth", map_location=device)
model_unfair.load_state_dict(chk['model_state_dict'])
model_unfair.eval()

# 5) Restore the debiased ("fair") classifier; same architecture, different weights
model_fair = OneStepLSTM(
    input_dim=X.size(1), hidden_dim=128, num_layers=1, num_classes=2
)
model_fair = model_fair.to(device)
chk = torch.load("fair_lstm_gender_classifier.pth", map_location=device)
model_fair.load_state_dict(chk['model_state_dict'])
model_fair.eval()

# 6) Batch-predict helper
def get_preds(model, X, batch_size=64):
    """Return the arg-max class index for every row of ``X``.

    Rows are fed to ``model`` in mini-batches on the module-level ``device``
    with gradient tracking disabled. The result is a NumPy array holding one
    predicted index per row, in input order.
    """
    batches = DataLoader(TensorDataset(X), batch_size=batch_size)
    predicted = []
    with torch.no_grad():
        # TensorDataset yields 1-tuples, hence the single-element unpacking.
        for (features,) in batches:
            scores = model(features.to(device))
            winners = torch.argmax(scores, dim=1)
            predicted.extend(winners.cpu().numpy())
    return np.array(predicted)

# Hard class predictions of both classifiers on the same embeddings
pred_unfair = get_preds(model=model_unfair, X=X)
pred_fair = get_preds(model=model_fair, X=X)

# 7) Wrap ground truth and predictions as AIF360 datasets
df_bld = pd.DataFrame({'gender': y, 'profession': S})
bld = BinaryLabelDataset(
    favorable_label=1,
    unfavorable_label=0,
    df=df_bld,
    label_names=['gender'],
    protected_attribute_names=['profession'],
)

# Each prediction dataset is a copy of the ground-truth one whose label
# column is replaced by the model output as an (n, 1) array.
pred_unfair_bld = bld.copy()
pred_unfair_bld.labels = pred_unfair.reshape(-1, 1)

pred_fair_bld = bld.copy()
pred_fair_bld.labels = pred_fair.reshape(-1, 1)

# 8) Privileged group = most frequent profession code; every other code
#    forms its own unprivileged group
priv_code = int(
    df_bld['profession'].value_counts().idxmax()
)
other = [
    int(code)
    for code in df_bld['profession'].unique()
    if code != priv_code
]
privileged_groups = [{'profession': priv_code}]
unprivileged_groups = [{'profession': code} for code in other]

# 9) Compare each model's predictions against the ground-truth dataset
m_unfair = ClassificationMetric(
    dataset=bld,
    classified_dataset=pred_unfair_bld,
    unprivileged_groups=unprivileged_groups,
    privileged_groups=privileged_groups,
)
m_fair = ClassificationMetric(
    dataset=bld,
    classified_dataset=pred_fair_bld,
    unprivileged_groups=unprivileged_groups,
    privileged_groups=privileged_groups,
)

# Accuracy, statistical parity difference and average odds difference
# for the baseline model ...
acc_u, spd_u, aod_u = (
    m_unfair.accuracy(),
    m_unfair.statistical_parity_difference(),
    m_unfair.average_odds_difference(),
)

# ... and the same three numbers for the debiased model.
acc_f, spd_f, aod_f = (
    m_fair.accuracy(),
    m_fair.statistical_parity_difference(),
    m_fair.average_odds_difference(),
)

# 10) Grouped bar chart: one pair of bars (Unfair / Fair) per metric
metrics = ['Accuracy','Statistical Parity Diff','Average Odds Diff']
vals_unfair = [acc_u, spd_u, aod_u]
vals_fair = [acc_f, spd_f, aod_f]

x = np.arange(len(metrics))
width = 0.35

fig, ax = plt.subplots()
# Shift each series half a bar width away from the tick so the pair
# sits centred on it.
ax.bar(x - width / 2, height=vals_unfair, width=width, label='Unfair')
ax.bar(x + width / 2, height=vals_fair, width=width, label='Fair')

ax.set_xticks(x)
ax.set_xticklabels(metrics)
ax.legend()

plt.xticks(rotation=15)
plt.tight_layout()
plt.show()
