# BERT
This notebook fine-tunes a BERT model (PubMedBERT) for bias classification, training one binary classifier per bias class.

## 1. Upload your datasets

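Upload three Excel (`.xlsx`) files: the train, validation and test sets. Each must contain a `sentence` column with the text, plus one binary (0/1) label column for each bias class: `DE`, `SE`, `NA` and `HD`.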

In [None]:
import panel as pn
# Initialise Panel's notebook extension with toast notifications enabled;
# `pn.state.notifications` is used further down to report missing uploads.
pn.extension(notifications=True)

In [None]:
from causation.utils import fileuploader 

# One file-input widget per split. For every split, `uploaded[split]` keeps the
# displayable labelled row, the widget itself, and the dict filled in by the
# uploader (its 'data' entry is read in the next cell).
sets = ['train', 'val', 'test']
uploaded = {}
for split_name in sets:
    file_input, upload_state = fileuploader('.xlsx')
    uploaded[split_name] = {
        'row': pn.Row(pn.pane.Str(f"{split_name} set:".rjust(10)), file_input),
        'finput': file_input,
        'upload': upload_state,
    }

pn.Column('# Upload datasets', *(entry['row'] for entry in uploaded.values()))

In [None]:
# Refuse to continue until every split has a file, and say which ones are missing.
missing = [name for name, entry in uploaded.items() if not entry['upload'].get('data', False)]
has_uploads = not missing
if not has_uploads:
    message = f"Did you upload all 3 datasets? Missing: {', '.join(missing)}."
    pn.state.notifications.error(message, duration=10_000)
    raise ValueError(message)
import pandas as pd
from atap_corpus.corpus import Corpus, Corpora
import spacy

def _load_split_corpus(split):
    """Read the uploaded Excel file for `split` into a Corpus named after the split.

    Document text is taken from the `sentence` column.
    """
    frame = pd.read_excel(uploaded[split]['upload'].get('data'))
    return Corpus.from_dataframe(frame, col_doc='sentence', name=split)


# Corpora in `sets` order: train, val, test.
corpora = Corpora(list(map(_load_split_corpus, sets)))

[(corpus.name, len(corpus)) for corpus in corpora.items()]

In [None]:
import torch
# Pick the first available accelerator (Apple MPS, then CUDA), else fall back to CPU.
_backend_checks = (
    ('mps', torch.backends.mps.is_available),
    ('cuda', torch.cuda.is_available),
)
device = next((name for name, is_available in _backend_checks if is_available()), 'cpu')
device

In [None]:
# Hugging Face checkpoint used for both the tokenizer and the sequence classifier.
MODEL = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

In [None]:
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import BatchEncoding
import numpy as np

class GeneDataset(Dataset):
    """Map-style torch Dataset pairing tokenizer encodings with labels.

    `encodings` is a mapping (e.g. a BatchEncoding) from feature name to a
    per-example sequence; `labels` is array-like with one entry per example.
    Each item is a dict of tensors: every encoding feature plus 'labels'.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = torch.tensor(labels)

    def __len__(self):
        # The dataset is sized by its labels, not by the encodings.
        return len(self.labels)

    def __getitem__(self, idx):
        features = dict()
        for feature_name, column in self.encodings.items():
            features[feature_name] = torch.tensor(column[idx])
        features['labels'] = self.labels[idx]
        return features


def _validate_split(split, name: str) -> None:
    """Assert that `split` is a `(texts: list, labels: np.ndarray)` pair of equal length.

    `name` is only used to make the length-mismatch message identify the split.
    """
    message = "Invalid data structure, use split_dataset to get tuples."
    assert isinstance(split, tuple) and len(split) == 2, message
    texts, labels = split
    assert isinstance(texts, list) and isinstance(labels, np.ndarray), message
    # GeneDataset sizes itself from its labels, so a mismatch would otherwise only
    # surface later as an indexing error while batching.
    assert len(texts) == len(labels), f"{name}: {len(texts)} texts but {len(labels)} labels."


def create_gene_datasets(bert_model: str, train: tuple, test: tuple, val: tuple = None,
                         max_length: int = 512) -> dict[str, GeneDataset]:
    """ Transform text dataset to bert encodings dataset.

    Args:
        bert_model: Hugging Face model id/path whose tokenizer is used.
        train: `(texts, labels)` where texts is a list of str and labels a numpy array.
        test: same structure as `train`.
        val: optional, same structure as `train`; adds a 'val' entry when not None.
        max_length: truncation length in tokens (default 512, as before).

    Returns:
        Dict with 'train' and 'test' GeneDatasets, plus 'val' when `val` is given.

    Raises:
        AssertionError: if any split is malformed or its texts/labels lengths differ,
            or if an 'uncased' model's tokenizer does not lower-case.
    """
    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    if 'uncased' in bert_model:
        # Guard against a tokenizer config that disagrees with the model name.
        assert tokenizer.do_lower_case

    splits = {'train': train, 'test': test}
    if val is not None:
        splits['val'] = val

    datasets = {}
    for name, split in splits.items():
        _validate_split(split, name)
        texts, labels = split
        # padding=True pads to the longest sequence within this split only.
        encodings: BatchEncoding = tokenizer(texts, truncation=True, padding=True, max_length=max_length)
        datasets[name] = GeneDataset(encodings, labels)
    return datasets



In [None]:
from transformers import Trainer
import evaluate
import numpy as np
import torch.nn as nn
import torch

# Define evaluation metric.
# Imported under an alias because a later cell defines a function named
# `evaluate`, which shadows the Hugging Face module when this cell is re-run.
import evaluate as hf_evaluate

metric = hf_evaluate.load('f1')


def compute_metrics(eval_pred, f1_metric=metric):
    """Trainer `compute_metrics` callback: F1 of the arg-max predictions.

    `eval_pred` unpacks to `(logits, labels)`. `f1_metric` is bound when this
    cell runs, so a later reassignment of the global name `metric` (the plotting
    loop reuses it) cannot break the callback.

    NOTE(review): a later cell redefines `compute_metrics` (weighted Jaccard)
    before the Trainer is built, so this version is not the one used in training.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return f1_metric.compute(predictions=predictions, references=labels)


class GeneTrainer(Trainer):
    """Trainer whose training loss is an explicit mean cross-entropy over the logits."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the cross-entropy loss, plus the model outputs when `return_outputs`.

        `num_items_in_batch` is accepted only for compatibility with Trainer
        versions that pass it; the loss here is always a per-batch mean.
        """
        labels = inputs.get("labels")
        # Forward pass without labels, so the model does not also compute (and
        # we then discard) its own internal loss.
        model_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.get("logits")
        # compute custom loss
        loss_fct = nn.CrossEntropyLoss(reduction='mean')
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

In [None]:
from datetime import datetime
import transformers
from transformers import TrainingArguments, AutoModelForSequenceClassification, EarlyStoppingCallback
from sklearn.metrics import jaccard_score
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np

transformers.logging.set_verbosity(transformers.logging.ERROR)   # stop from_pretrained calls to output loading from config, weights... logs. But will load errors.

def compute_metrics(eval_pred):
    """Trainer `compute_metrics` callback: weighted Jaccard score of the predictions.

    `eval_pred` unpacks to `(logits, labels)`; logits have one column per class
    (the model is built with problem_type='single_label_classification').
    Returns {'jaccard': float}.
    """
    logits, labels = eval_pred
    # Single-label classification: the predicted class is the arg-max logit.
    # Rounding sigmoids first (as this used to) collapses rows whose sigmoids
    # round to the same value onto class 0, whichever logit is actually larger.
    predictions = np.argmax(logits, axis=-1)
    return {'jaccard': jaccard_score(labels, predictions, average='weighted')}

# Timestamped run name for the checkpoint directory. '-' rather than ':' in the
# time part keeps the path valid on filesystems (e.g. Windows) that reject colons.
trial = f"trial_{datetime.now().strftime('%d-%m-%YT%H-%M-%S')}"
model = AutoModelForSequenceClassification.from_pretrained(MODEL,
                                                           problem_type='single_label_classification').to(device)
args = TrainingArguments(
    output_dir=f'./.output/{trial}',
    # Evaluate and checkpoint on the same 20-step cadence so load_best_model_at_end
    # has a saved checkpoint for every evaluation.
    evaluation_strategy='steps',
    save_strategy='steps',
    eval_steps=20,
    save_steps=20,
    # Upper bound only: training also stops via the EarlyStoppingCallback used with the Trainer.
    num_train_epochs=40,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    logging_dir='./logs',
    logging_steps=10,
    # NOTE(review): no metric_for_best_model is set, so "best" is chosen by the
    # Trainer's default criterion (eval loss per the TrainingArguments docs) - confirm intended.
    load_best_model_at_end=True,
    disable_tqdm=False,
    log_level='error',
    use_mps_device=device == 'mps',
    seed=42,
    data_seed=42,
    optim='adamw_hf'
)

In [None]:
# Look the splits up by the names given when the corpora were built, rather than
# relying on their position in `corpora`.
splits_by_name = {corpus.name: corpus for corpus in corpora.items()}
train = splits_by_name['train']
test = splits_by_name['test']
val = splits_by_name['val']


def _texts_and_labels(corpus, clz):
    """Return `(documents as a list, the `clz` label column as a numpy array)` for `corpus`."""
    return corpus.docs().tolist(), np.array(corpus[clz].tolist())


results = dict()
classes = ['DE', 'SE', 'NA', 'HD']
for clz in classes:
    # Balance the training split: every biased (1) sentence plus an equally sized
    # random sample of neutral (0) ones.
    neutral = train.s.filter_by_item(name=clz, items=0)
    biased = train.s.filter_by_item(name=clz, items=1)
    balanced_train = biased.join(neutral.sample(len(biased), rand_stat=42))

    datasets = create_gene_datasets(bert_model=MODEL,
                                    train=_texts_and_labels(balanced_train, clz),
                                    test=_texts_and_labels(test, clz),
                                    val=_texts_and_labels(val, clz))

    # Start every bias class from the pretrained checkpoint. Reusing one model
    # object would fine-tune each class on top of the previous class's weights.
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL, problem_type='single_label_classification').to(device)

    trainer = GeneTrainer(model=model, args=args, compute_metrics=compute_metrics,
                          train_dataset=datasets.get('train'), eval_dataset=datasets.get('val'),
                          callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.0)])
    trainer.train()
    outputs = trainer.predict(datasets.get('test'))
    preds, labels = outputs.predictions, outputs.label_ids
    # Single-label classification: predicted class = arg-max logit.
    pred = np.argmax(preds, axis=-1)
    results[clz] = {'labels': labels, 'preds': pred}

In [None]:
import tempfile
from pathlib import Path

# Fresh scratch directory created on each run. Later cells write the text report
# and the plot into it, and everything found in it is added to the final zip.
tmpd = Path(tempfile.mkdtemp())

In [None]:
from sklearn.metrics import classification_report
from typing import IO
import io

def evaluate(y_pred: np.ndarray, y_true: np.ndarray, file: IO, labels=None, **kwargs):
    """Write a text `classification_report` for one set of predictions to `file`.

    Extra keyword arguments are forwarded to `classification_report`.
    NOTE: this name shadows the Hugging Face `evaluate` module imported earlier.

    Raises:
        AssertionError: if shapes differ or `file` is not writable.
    """
    assert y_pred.shape == y_true.shape, "Mismatched shape between y_pred and y_true."
    assert file.writable(), "File is not writable."
    report = classification_report(y_pred=y_pred, y_true=y_true, output_dict=False, labels=labels, **kwargs)
    file.write(report)


classifier = "BERT"
# Build the per-class report once in memory, then save and print that same text.
s = io.StringIO()
for clazz, res in results.items():
    s.write("===" + clazz + "===\n")
    evaluate(res['preds'], res['labels'], file=s)
report_text = s.getvalue()
s.close()
# `with` guarantees the file is closed even if writing fails.
with open(tmpd.joinpath(f"{classifier}.txt"), mode='w', encoding='utf-8') as file:
    file.write(report_text)
print(report_text)

In [None]:
def evaluate(y_pred: np.ndarray, y_true: np.ndarray, **kwargs):
    """Return `classification_report` as a dict (keyword arguments are forwarded).

    Raises:
        AssertionError: if `y_pred` and `y_true` shapes differ.
    """
    assert y_pred.shape == y_true.shape, "Mismatched shape between y_pred and y_true."
    report = classification_report(y_pred=y_pred, y_true=y_true, output_dict=True, **kwargs)
    return report

r_dfs = list()
for clazz, res in results.items():
    # Pin the label set so the '0' and '1' rows always exist, even when one label
    # is absent from both the test labels and the predictions of a class.
    report = evaluate(res['preds'], res['labels'], labels=[0, 1])
    r_df = pd.DataFrame.from_dict(report).T.loc[['0', '1'], ['precision', 'recall', 'f1-score']]
    r_dfs.append(r_df)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

# Select a colormap. `plt.get_cmap` is used because `matplotlib.cm.get_cmap`
# was removed in matplotlib 3.9.
metric_names = ['precision', 'recall', 'f1-score']
cmap = plt.get_cmap('tab20c')
# One colour per plotted metric (the colours are indexed by metric below).
colors = cmap(np.linspace(0, 1, len(metric_names)))

fig, ax = plt.subplots(1, 1, figsize=(12, 6))
categories = list(results.keys())

# Scores of the biased ('1') row of each class's report; r_dfs follows results' order.
values = np.array([[r_df.loc['1', name] for name in metric_names] for r_df in r_dfs])

for i, metric_name in enumerate(metric_names):
    ax.scatter(categories, values[:, i], color=colors[i], label=metric_name)

ax.set_title(classifier)
ax.set_xlabel('Bias', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.grid(True)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
fig.tight_layout()
fig.savefig(tmpd.joinpath('plot.png'))
plt.show()

In [None]:
import zipfile
import os
from datetime import datetime
from pathlib import Path
import panel as pn

# Bundle everything in the scratch dir (report, plot) with the three uploaded spreadsheets.
now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
zfname = Path(f'{now}-{classifier}.zip')
file_names = [*tmpd.rglob("*"), *(entry['upload']['data'] for entry in uploaded.values())]
with zipfile.ZipFile(zfname, 'w') as zipf:
    for file_name in file_names:
        # Stored flat: each file goes into the archive under its base name only.
        zipf.write(file_name, arcname=os.path.basename(file_name))
print(f"Saved as {zfname}.\nClick below to download.")

# download link for the zip.
pn.widgets.FileDownload(file=str(zfname), filename=zfname.name)