In [None]:
import sys

# Prepend the parent directory so the project package imported below as `src.*`
# resolves when the notebook runs from its own folder.
sys.path[0:0] = [".."]

In [None]:
import os
import math
import argparse
import ruamel.yaml as yaml
import umap
import torch
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import seaborn as sns
from transformers import AutoTokenizer
from tqdm.notebook import tqdm, trange
import pickle
from sklearn.manifold import TSNE
from torch.utils.data import TensorDataset, DataLoader, Dataset

from src.models.model_diff_modular import ModularDiffModel
from src.adv_attack import get_hidden_dataloader
from src.utils import get_num_labels, dict_to_device
from src.data_handler import get_data_loader_bios, read_label_file


In [None]:
def get_data_sorted(args_train):
    """Build unshuffled BIOS train/val loaders and load the raw pickled records.

    Args:
        args_train: namespace providing ``labels_task_path``, ``labels_protected_path``,
            ``model_name``, ``train_pkl``, ``val_pkl`` and ``batch_size``.

    Returns:
        Tuple ``(train_loader, val_loader, num_labels, num_labels_protected,
        data_dicts_train, data_dicts_val)``. The last two are whatever ``pickle.load``
        returns for ``train_pkl`` / ``val_pkl``.
    """
    num_labels = get_num_labels(args_train.labels_task_path)
    num_labels_protected = get_num_labels(args_train.labels_protected_path)
    tokenizer = AutoTokenizer.from_pretrained(args_train.model_name)

    def _make_loader(pkl_path):
        # shuffle=False gives a deterministic batch order (presumably the record order of
        # the pickle, which later cells rely on to map embedding rows back to records).
        return get_data_loader_bios(
            tokenizer=tokenizer,
            data_path=pkl_path,
            labels_task_path=args_train.labels_task_path,
            labels_prot_path=args_train.labels_protected_path,
            batch_size=args_train.batch_size,
            max_length=200,
            shuffle=False,
            debug=False,
        )

    def _load_records(pkl_path):
        # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
        with open(pkl_path, "rb") as handle:
            return pickle.load(handle)

    train_loader = _make_loader(args_train.train_pkl)
    val_loader = _make_loader(args_train.val_pkl)
    data_dicts_train = _load_records(args_train.train_pkl)
    data_dicts_val = _load_records(args_train.val_pkl)
    return (
        train_loader,
        val_loader,
        num_labels,
        num_labels_protected,
        data_dicts_train,
        data_dicts_val,
    )

def get_embeddings(model, loader):
    """Run ``model._forward`` over every batch of ``loader`` and collect outputs on CPU.

    Args:
        model: object exposing ``eval()``, ``device`` and ``_forward(**inputs)``.
        loader: iterable of ``(inputs, labels_task, labels_prot)`` batches, where
            ``inputs`` is a dict passed through ``dict_to_device`` to ``model.device``.

    Returns:
        ``(embeddings, labels_task, labels_prot)``, each concatenated along dim 0
        over all batches. The embeddings are detached (no autograd history).
    """
    model.eval()
    emb_list = []
    label_list_task = []
    label_list_prot = []
    # Inference only. Without no_grad, each stored output keeps its batch's autograd
    # graph alive, so memory grows with the dataset and the result requires grad.
    with torch.no_grad():
        for batch in tqdm(loader, desc="generating embeddings"):
            inputs, labels_task, labels_prot = batch
            inputs = dict_to_device(inputs, model.device)
            hidden = model._forward(**inputs)
            emb_list.append(hidden.cpu())
            label_list_task.append(labels_task)
            label_list_prot.append(labels_prot)
    return torch.cat(emb_list), torch.cat(label_list_task), torch.cat(label_list_prot)

In [None]:
# Folder, relative to the notebook's working directory, that holds the cached pickles
# (embeddings, UMAP projection, per-job chart data) and the saved per-job figures.
output_folder = "data"

In [None]:
# with open("../cfg.yml", "r") as f:
#     cfg = yaml.safe_load(f)
# args_train = argparse.Namespace(**cfg["train_config"], **cfg["data_config_bios"], **cfg["model_config"])

# setattr(args_train, "train_pkl", "/share/cp/datasets/nlp/text_classification_bias/bios/only_task_balanced/train.pkl")

In [None]:
# train_loader, val_loader, num_labels, num_labels_protected, data_dicts_train, data_dicts_val = get_data_sorted(args_train)

In [None]:
# # cp_dir = "/share/home/lukash/checkpoints_bert_L4/seed0"
# # cp = "bert_uncased_L-4_H-256_A-4-fixmask0.1-modular.pt"
# cp_dir = "/share/home/lukash/checkpoints_bert_L4/seed{}"
# cp = "bert_uncased_L-4_H-256_A-4-fixmask0.1-modular-sparse_task.pt"
# cp_dir = "../checkpoints_bios"
# cp = "bert_uncased_L-4_H-256_A-4-modular_baseline-seed{}.pt"

# data = []
# for i in range(4):
#     filepath = os.path.join(cp_dir.format(i), cp)
#     model_biased = ModularDiffModel.load_checkpoint(filepath, remove_parametrizations=True, debiased=False)
#     model_debiased = ModularDiffModel.load_checkpoint(filepath, remove_parametrizations=True, debiased=True)
#     setattr(args_train, "model_name", model_biased.model_name)

#     model_biased.to("cuda:2")
#     model_debiased.to("cuda:2")

#     emb_train_biased, labels_train_task, labels_train_prot = get_embeddings(model_biased, train_loader)
#     emb_val_biased, labels_val_task, labels_val_prot = get_embeddings(model_biased, val_loader)
#     emb_train_debiased, _, _ = get_embeddings(model_debiased, train_loader)
#     emb_val_debiased, _, _ = get_embeddings(model_debiased, val_loader)
#     data.append([emb_train_biased, emb_train_debiased, emb_val_biased, emb_val_debiased, labels_train_task, labels_val_task, labels_train_prot, labels_val_prot])

#     model_biased.cpu()
#     model_debiased.cpu()

In [None]:
# with open(os.path.join(output_folder, "modular_diff_embeddings.pkl"), "wb") as f:
#     pickle.dump(data, f)

In [None]:
# Cached result of the (commented-out) extraction cells: a list with one entry per
# checkpoint, each [emb_train_biased, emb_train_debiased, emb_val_biased, emb_val_debiased,
# labels_train_task, labels_val_task, labels_train_prot, labels_val_prot].
# NOTE: pickle.load can execute arbitrary code; only load files produced by this project.
embeddings_pkl = os.path.join(output_folder, "modular_diff_embeddings.pkl")
with open(embeddings_pkl, "rb") as fh:
    data = pickle.load(fh)

In [None]:
# Transpose the per-checkpoint entries with zip(*data), then concatenate each field
# across checkpoints: rows of checkpoint 0 come first, then checkpoint 1, and so on.
(
    emb_train_biased, emb_train_debiased,
    emb_val_biased, emb_val_debiased,
    labels_train_task, labels_val_task,
    labels_train_prot, labels_val_prot,
) = map(torch.cat, zip(*data))

In [None]:
# Each tensor unpacked above is the concatenation of one equally sized block per
# checkpoint (every checkpoint embeds the same loader), so tag each row with the index
# of the checkpoint it came from. The count is taken from the data itself rather than
# hard-coded: a wrong constant yields index tensors of the wrong length that contain a
# non-existent checkpoint id. The extraction loop above ran over 4 checkpoints.
n_checkpoints = len(data)
assert emb_train_biased.shape[0] % n_checkpoints == 0, "train rows not divisible by checkpoint count"
assert emb_val_biased.shape[0] % n_checkpoints == 0, "val rows not divisible by checkpoint count"
cp_idx_train = torch.repeat_interleave(torch.arange(n_checkpoints), emb_train_biased.shape[0] // n_checkpoints)
cp_idx_val = torch.repeat_interleave(torch.arange(n_checkpoints), emb_val_biased.shape[0] // n_checkpoints)

In [None]:
# Biased rows first, debiased rows second. Later cells split the UMAP projection at
# cutoff_train / cutoff_val and rely on this order.
train_embeddings = torch.cat((emb_train_biased, emb_train_debiased), dim=0)
val_embeddings = torch.cat((emb_val_biased, emb_val_debiased), dim=0)

In [None]:
# umap_reducer = umap.UMAP()
# emb_umap = umap_reducer.fit_transform(torch.cat([train_embeddings, val_embeddings]).numpy())

In [None]:
# with open(os.path.join(output_folder, "modular_diff_embeddings_umap.pkl"), "wb") as f:
#     pickle.dump((emb_umap, umap_reducer), f)

In [None]:
# (projection, fitted reducer) pair saved by the commented-out cell above. The projection
# was fit on torch.cat([train_embeddings, val_embeddings]), so training rows come first.
# NOTE: pickle.load can execute arbitrary code; only load files produced by this project.
umap_pkl = os.path.join(output_folder, "modular_diff_embeddings_umap.pkl")
with open(umap_pkl, "rb") as fh:
    emb_umap, umap_reducer = pickle.load(fh)

In [None]:
# tsne_reducer = TSNE()
# emb_tsne = tsne_reducer.fit_transform(torch.cat([train_embeddings, val_embeddings]).numpy())

In [None]:
# Undo the train/val concatenation the UMAP projection was fit on.
n_train_rows = train_embeddings.shape[0]
emb_umap_train, emb_umap_val = emb_umap[:n_train_rows], emb_umap[n_train_rows:]

In [None]:
# Row offsets where the debiased half starts inside each projected split.
cutoff_train, cutoff_val = emb_train_biased.size(0), emb_val_biased.size(0)

In [None]:
# Label-name -> index mappings: occupations (task) and, presumably likewise, gender
# (protected attribute). The dataset root is one constant, overridable through the
# BIOS_DATA_DIR environment variable; the default is the original cluster location.
bios_data_dir = os.environ.get("BIOS_DATA_DIR", "/share/cp/datasets/nlp/text_classification_bias/bios")
job_dict = read_label_file(os.path.join(bios_data_dir, "labels_task.txt"))
gender_dict = read_label_file(os.path.join(bios_data_dir, "labels_protected_gender.txt"))

In [None]:
# c = np.concatenate([
#     np.zeros((emb_dict["biased_train"][0].shape[0],), int),
#     np.ones((emb_dict["debiased_train"][0].shape[0],), int)
# ], axis=0)
# labels = ["biased"] * emb_dict["biased_train"][0].shape[0] + ["debiased"] * emb_dict["debiased_train"][0].shape[0]

In [None]:
# Overview of the full UMAP projection. Rows are train / validation; columns are the
# biased (blue) and debiased (orange) halves split at cutoff_train / cutoff_val.
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10,10))
overview_panels = [
    (axs[0, 0], emb_umap_train[:cutoff_train], "blue", "Training Data - Biased"),
    (axs[0, 1], emb_umap_train[cutoff_train:], "orange", "Training Data - Debiased"),
    (axs[1, 0], emb_umap_val[:cutoff_val], "blue", "Validation Data - Biased"),
    (axs[1, 1], emb_umap_val[cutoff_val:], "orange", "Validation Data - Debiased"),
]
for ax, points, color, panel_title in overview_panels:
    ax.scatter(x=points[:, 0], y=points[:, 1], c=color, edgecolors='black')
    ax.set_title(panel_title)
fig.suptitle('UMAP Embeddings')
fig.tight_layout()

In [None]:
# chart_dict = {}
# for job_title, job_idx in tqdm(job_dict.items()):

#     f_train = (labels_train_task == job_idx)
#     f_val = (labels_val_task == job_idx)
#     _train_embeddings = torch.cat([emb_train_biased[f_train], emb_train_debiased[f_train]])
#     _val_embeddings = torch.cat([emb_val_biased[f_val], emb_val_debiased[f_val]])

#     _umap_reducer = umap.UMAP()
#     _emb_umap = _umap_reducer.fit_transform(torch.cat([_train_embeddings, _val_embeddings]).numpy())
#     _emb_umap_train = _emb_umap[:_train_embeddings.shape[0]]
#     _emb_umap_val = _emb_umap[_train_embeddings.shape[0]:]

#     _c_train = ["blue" if x else "orange" for x in labels_train_prot[f_train]]
#     _c_val = ["blue" if x else "orange" for x in labels_val_prot[f_val]]

#     _cutoff_train = f_train.sum().item()
#     _cutoff_val = f_val.sum().item()

#     _emb_umap_train_biased = _emb_umap_train[:_cutoff_train]
#     _emb_umap_train_debiased = _emb_umap_train[_cutoff_train:]
#     _emb_umap_val_biased = _emb_umap_val[:_cutoff_val]
#     _emb_umap_val_debiased = _emb_umap_val[_cutoff_val:]

#     chart_dict[job_idx] = (_emb_umap_train_biased, _emb_umap_train_debiased, _c_train, _emb_umap_val_biased, _emb_umap_val_debiased, _c_val)

In [None]:
# with open("job_emb_dict.pkl", "wb") as f:
#     pickle.dump(chart_dict, f)

In [None]:
# Per-occupation UMAP projections: job index -> (train_biased, train_debiased, train colours,
# val_biased, val_debiased, val colours), as built by the commented-out cell above.
# NOTE: pickle.load can execute arbitrary code; only load files produced by this project.
chart_pkl = os.path.join(output_folder, "job_emb_dict.pkl")
with open(chart_pkl, "rb") as fh:
    chart_dict = pickle.load(fh)

In [None]:
# One figure per occupation, drawn from the per-job UMAP projections in chart_dict.
# Rows are train / validation and columns are biased / debiased. Points are coloured by
# protected label ("blue" if label else "orange", as in the cell that built chart_dict).
idx_to_job = {v: k for k, v in job_dict.items()}  # invert once, not once per job
# NOTE(review): the legend assumes protected label 0 (orange) is male and 1 (blue) is
# female; confirm against gender_dict.
legend_handles = [Line2D([0], [0], marker='o', linestyle='', color=c) for c in ["orange", "blue"]]

# Named `job_data` rather than `data` so the loop does not clobber the per-checkpoint
# list loaded into `data` earlier, which earlier cells need on re-run.
for job_idx, job_data in chart_dict.items():
    job_title = idx_to_job[job_idx]
    (_emb_umap_train_biased, _emb_umap_train_debiased, _c_train,
     _emb_umap_val_biased, _emb_umap_val_debiased, _c_val) = job_data

    fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10,10))
    job_panels = [
        (axs[0, 0], _emb_umap_train_biased, _c_train, "Training Data - Biased"),
        (axs[0, 1], _emb_umap_train_debiased, _c_train, "Training Data - Debiased"),
        (axs[1, 0], _emb_umap_val_biased, _c_val, "Validation Data - Biased"),
        (axs[1, 1], _emb_umap_val_debiased, _c_val, "Validation Data - Debiased"),
    ]
    for ax, points, colors, panel_title in job_panels:
        ax.scatter(x=points[:, 0], y=points[:, 1], c=colors, edgecolors='black')
        ax.set_title(panel_title)
    fig.legend(legend_handles, ["Male", "Female"])
    pretty_job_title = " ".join(x.capitalize() for x in job_title.split("_"))
    fig.suptitle(f'UMAP Embeddings - {pretty_job_title}')
    fig.tight_layout()
    fig.savefig(os.path.join(output_folder, f"umap_{job_title}.png"))
    plt.show()
    # Release each figure explicitly so non-inline backends do not accumulate one per job.
    plt.close(fig)

In [None]:
# for job_title, job_idx in job_dict.items():

#     emb_train_biased = emb_umap_train[:cutoff_train][(labels_train_task == job_idx)]
#     emb_train_debiased = emb_umap_train[cutoff_train:][(labels_train_task == job_idx)]
#     emb_val_biased = emb_umap_val[:cutoff_val][(labels_val_task == job_idx)]
#     emb_val_debiased = emb_umap_val[cutoff_val:][(labels_val_task == job_idx)]
#     c_train = ["blue" if x else "orange" for x in labels_train_prot[(labels_train_task == job_idx)]]
#     c_val = ["blue" if x else "orange" for x in labels_val_prot[(labels_val_task == job_idx)]]

#     fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10,10))
#     axs[0,0].scatter(x=emb_train_biased[:,0], y=emb_train_biased[:,1], c=c_train, edgecolors='black')
#     axs[0,0].set_title("Training Data - Biased")
#     axs[0,1].scatter(x=emb_train_debiased[:,0], y=emb_train_debiased[:,1], c=c_train, edgecolors='black')
#     axs[0,1].set_title("Training Data - Debiased")
#     axs[1,0].scatter(x=emb_val_biased[:,0], y=emb_val_biased[:,1], c=c_val, edgecolors='black')
#     axs[1,0].set_title("Valdiation Data - Biased")
#     axs[1,1].scatter(x=emb_val_debiased[:,0], y=emb_val_debiased[:,1], c=c_val, edgecolors='black')
#     axs[1,1].set_title("Validation Data - Debiased")
#     fig.suptitle(f'UMAP Embeddings - {job_title}')
#     fig.tight_layout()

In [None]:
# Per-row Euclidean distance between the biased and debiased embedding of the same input.
distances = (emb_train_biased - emb_train_debiased).norm(dim=1)

In [None]:
# The 10 rows whose embedding moved the most between the biased and debiased model.
v, idx = distances.topk(10)

In [None]:
# emb_train_biased stacks one block of training rows per checkpoint (torch.cat across
# checkpoints above), so a topk row index can exceed the number of training records.
# Reduce it modulo the record count to recover the record it belongs to.
# NOTE(review): assumes each block has len(data_dicts_train) rows in pickle order
# (train loader built with shuffle=False); confirm.
# NOTE: data_dicts_train is only defined by the currently commented-out get_data_sorted cell.
n_train_records = len(data_dicts_train)
assert emb_train_biased.shape[0] % n_train_records == 0, "embedding rows do not tile the training records"
dicts = [data_dicts_train[i % n_train_records] for i in idx.tolist()]

In [None]:
# [(d["title"], d["gender"], d["bio"]) for d in dicts]

In [None]:
# Square linear map to be fit from biased to debiased embeddings (same hidden size on both sides).
sh = emb_train_biased.size(1)
linear_transform = torch.nn.Linear(in_features=sh, out_features=sh)
# reduction="none" keeps element-wise errors so the training loop can report per-sample losses.
loss = torch.nn.MSELoss(reduction="none")
opt = torch.optim.SGD(params=linear_transform.parameters(), lr=1e-4)

In [None]:
# Row-aligned (biased, debiased) embedding pairs, reshuffled every epoch; the final
# partial batch is kept.
ds = TensorDataset(emb_train_biased, emb_train_debiased)
loader = DataLoader(dataset=ds, batch_size=32, shuffle=True, drop_last=False)

In [None]:
# Fit the linear map with minibatch SGD on the mean element-wise MSE. After each epoch,
# report the mean per-sample loss (squared error summed over the embedding dimension).
linear_transform.train()
s = "loss: {:.2f}"
losses = []
train_iterator = trange(20, desc=s.format(math.nan), position=0, leave=False)
for epoch in train_iterator:
    epoch_sample_losses = []
    for x_batch, y_batch in tqdm(loader, position=1, leave=False):
        elementwise_loss = loss(linear_transform(x_batch), y_batch)
        elementwise_loss.mean().backward()
        opt.step()
        opt.zero_grad()
        # detached so the epoch's bookkeeping does not retain autograd graphs
        epoch_sample_losses.append(elementwise_loss.sum(1).detach())
    avg_loss = torch.cat(epoch_sample_losses).mean()
    losses.append(avg_loss)
    train_iterator.set_description(s.format(avg_loss), refresh=True)

In [None]:
# Mean per-sample training loss of the last epoch; shown as the cell's output.
avg_loss