In [None]:
import torch
import torch.nn as nn

import transformers

from transformers import pipeline, AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling, default_data_collator, TrainingArguments, Trainer

from datasets import load_dataset
import evaluate

import random

import pandas as pd
import numpy as np
import collections


from metrics.crows_pairs import *
from metrics.stereoset.eval_discriminative_models import *
from datetime import datetime

In [None]:
# Display today's date in the same YYYY-MM-DD form used for the save-directory suffixes.
f"{datetime.now():%Y-%m-%d}"

In [None]:
# Reload all modules automatically before each cell executes, so edits to the
# local `metrics` package imported above are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Run configuration: checkpoint, CrowS-Pairs input file, and date-stamped model output dirs.
model_checkpoint = "bert-base-cased"
cp_input_file = '/home/bhatt/ishan/TUM_Thesis/data/metrics_ds/crows-pairs/data/crows_pairs_anonymized.csv'
models_dir = "/home/bhatt/ishan/TUM_Thesis/data/models"
# Read the clock once so both directories always share the same date suffix,
# even if this cell happens to run across midnight.
run_date = datetime.now().strftime("%Y-%m-%d")
# Hub ids may contain "/" (e.g. "org/model"); keep only the last component so the
# id does not create nested directories (same rule as `model_name` in the
# training-arguments cell).
checkpoint_name = model_checkpoint.split("/")[-1]
model_save_dir = f"{models_dir}/{checkpoint_name}_{run_date}"
ft_model_save_dir = f"{models_dir}/{checkpoint_name}-finetuned-imdb_{run_date}"

### Load pre-trained model for training

In [None]:
# Tokenizer and masked-LM model for the same checkpoint (loaded independently).
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

### Calculate the CrowS-Pairs metric (pre-trained baseline)

In [None]:
# Baseline CrowS-Pairs evaluation of the pre-trained (not yet fine-tuned) model.
output_file = '/home/bhatt/ishan/TUM_Thesis/data/results/cp_results.csv'
# `get_results` comes from the star-import of metrics.crows_pairs; it is called
# positionally as (input CSV, output CSV, model, tokenizer) — see that module for
# the exact signature.
get_results(cp_input_file,output_file,model,tokenizer)

### Calculate the StereoSet metric (pre-trained baseline)

In [None]:
# Baseline StereoSet (intrasentence) evaluation on the dev split for the
# pre-trained model; results are written to data/results/stereoset_results.txt.
getStereoSet(
    input_file='/home/bhatt/ishan/TUM_Thesis/data/metrics_ds/stereoset/dev.json',
    output_dir='/home/bhatt/ishan/TUM_Thesis/data/results',
    output_file='stereoset_results.txt',
    pretrained_class=model_checkpoint,
    intrasentence_model=model,
    tokenizer=tokenizer,
)

### Fine-tune the model on IMDb (masked language modeling)

In [None]:
# Model size in millions of parameters.
model.num_parameters() / 10**6

In [None]:
# IMDb movie-review corpus from the Hugging Face Hub (cached locally after the first download).
imdb_dataset = load_dataset(path="imdb")
imdb_dataset

In [None]:
def tokenize_function(examples):
    """Tokenize a batch of raw reviews for ``Dataset.map(batched=True)``.

    ``examples["text"]`` holds the batch of review strings. With a fast
    (Rust-backed) tokenizer, a ``word_ids`` column is added as well: for every
    example, the word index of each token (``None`` for special tokens), which
    the whole-word-masking collator uses to group sub-word tokens.
    """
    encoded = tokenizer(examples["text"])
    if not tokenizer.is_fast:
        # word_ids() is only available on fast-tokenizer outputs.
        return encoded
    encoded["word_ids"] = [encoded.word_ids(row) for row in range(len(encoded["input_ids"]))]
    return encoded


# batched=True passes many rows per call, so the fast tokenizer can encode them
# together. The raw "text" and sentiment "label" columns are dropped; only the
# tokenizer outputs (plus word_ids) remain.
tokenized_datasets = imdb_dataset.map(
    tokenize_function,
    remove_columns=["text", "label"],
    batched=True,
)
tokenized_datasets

In [None]:
chunk_size = 128


def group_texts(examples, size=None):
    """Concatenate a batch of tokenized texts and re-split it into fixed-size chunks.

    Intended for ``Dataset.map(..., batched=True)``: every column of ``examples``
    is a list of per-example lists. Each column is concatenated end to end and cut
    into consecutive chunks of ``size`` items. The trailing remainder shorter than
    ``size`` is dropped. A ``labels`` column is added as a copy of ``input_ids``.

    Args:
        examples: mapping of column name -> list of lists (one list per example).
        size: chunk length; when None, the module-level ``chunk_size`` is read at
            call time (the original behavior).

    Returns:
        dict with the same columns plus ``labels``, each a list of chunks.
    """
    if size is None:
        size = chunk_size
    # Flatten each column in linear time (sum(lists, []) is quadratic in the batch size).
    concatenated_examples = {
        k: [item for seq in seqs for item in seq] for k, seqs in examples.items()
    }
    # All columns have the same total length; use the first one as the reference.
    total_length = len(next(iter(concatenated_examples.values())))
    # Drop the last chunk if it is smaller than `size`.
    total_length = (total_length // size) * size
    result = {
        k: [t[i : i + size] for i in range(0, total_length, size)]
        for k, t in concatenated_examples.items()
    }
    # Copy each chunk so `labels` never aliases the `input_ids` lists.
    result["labels"] = [list(chunk) for chunk in result["input_ids"]]
    return result

In [None]:
# Regroup every split into fixed-length chunks with `group_texts` (adds a labels column).
lm_datasets = tokenized_datasets.map(function=group_texts, batched=True)
lm_datasets

In [None]:
# Token-level random masking collator: masks tokens with probability 0.15 per batch.
data_collator = DataCollatorForLanguageModeling(mlm_probability=0.15, tokenizer=tokenizer)

In [None]:
wwm_probability = 0.2


def whole_word_masking_data_collator(features):
    """Collate LM chunks into a batch, masking whole words instead of single tokens.

    Each feature must carry ``input_ids``, ``labels`` and the ``word_ids`` column
    produced by ``tokenize_function`` (fast tokenizers only). Every word is
    selected independently with probability ``wwm_probability``. All tokens of a
    selected word are replaced by ``tokenizer.mask_token_id`` and keep their
    original id in ``labels``. Every other label position is set to -100, so the
    loss ignores it.

    The caller's feature dicts and their lists are not modified. The same
    ``features`` can therefore be collated repeatedly without a KeyError on
    ``word_ids`` or already-masked inputs.

    Args:
        features: list of per-example dicts (e.g. rows of ``lm_datasets["train"]``).

    Returns:
        The batch built by ``default_data_collator`` (without ``word_ids``).
    """
    collated = []
    for original in features:
        # Shallow copy so popping "word_ids" does not mutate the caller's row.
        feature = dict(original)
        word_ids = feature.pop("word_ids")

        # Map a running word index -> token positions; special tokens (word_id None) are skipped.
        mapping = collections.defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)

        # One Bernoulli(wwm_probability) draw per word.
        mask = np.random.binomial(1, wwm_probability, (len(mapping),))
        # Copy so masking does not overwrite the caller's input_ids list.
        input_ids = list(feature["input_ids"])
        labels = feature["labels"]
        new_labels = [-100] * len(labels)
        for word_index in np.flatnonzero(mask):
            for idx in mapping[int(word_index)]:
                new_labels[idx] = labels[idx]
                input_ids[idx] = tokenizer.mask_token_id
        feature["input_ids"] = input_ids
        feature["labels"] = new_labels
        collated.append(feature)

    return default_data_collator(collated)

In [None]:
# Sanity-check the whole-word collator on the first two training chunks.
samples = [lm_datasets["train"][row] for row in (0, 1)]
batch = whole_word_masking_data_collator(samples)

for chunk in batch["input_ids"]:
    decoded = tokenizer.decode(chunk)
    print(f"\n'>>> {decoded}'")

In [None]:
# Keep the experiment small: 10k training chunks plus a held-out split of 10% of that size.
train_size = 10_000
test_size = train_size // 10

downsampled_dataset = lm_datasets["train"].train_test_split(
    seed=42,
    test_size=test_size,
    train_size=train_size,
)
downsampled_dataset

In [None]:
batch_size = 64
# Log the training loss roughly once per epoch (len(train) // batch_size steps on one device).
# Clamp to 1: a train split smaller than batch_size would give 0, which
# TrainingArguments rejects for step-based logging.
logging_steps = max(1, len(downsampled_dataset["train"]) // batch_size)
# Checkpoint id without any "org/" prefix; not used in this cell.
model_name = model_checkpoint.split("/")[-1]

training_args = TrainingArguments(
    output_dir=ft_model_save_dir,
    overwrite_output_dir=True,
    # NOTE: deprecated in favour of `eval_strategy` in newer transformers versions.
    evaluation_strategy="epoch",
    num_train_epochs=10,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    push_to_hub=False,
    # fp16 mixed precision requires a GPU; fall back to fp32 when CUDA is absent
    # instead of failing on CPU-only machines.
    fp16=torch.cuda.is_available(),
    logging_steps=logging_steps,
)

In [None]:
# Trainer for MLM fine-tuning. It uses the token-level `data_collator`,
# not `whole_word_masking_data_collator`.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=data_collator,
    train_dataset=downsampled_dataset["train"],
    eval_dataset=downsampled_dataset["test"],
)

In [None]:
# Release cached, unused GPU memory held by PyTorch before training
# (useful after `del`-ing an earlier trainer/model; a no-op without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Fine-tune on the downsampled IMDb chunks; the cell output is the returned TrainOutput summary.
trainer.train()

In [None]:
# Re-score CrowS-Pairs with the fine-tuned weights, using the same input CSV as the baseline run.
input_file = cp_input_file
output_file = '/home/bhatt/ishan/TUM_Thesis/data/results/cp_results_fine_tuned.csv'
# Training may leave the model in train mode (dropout active). Whether get_results
# switches modes itself is not visible here, so force eval mode. This keeps the
# scores deterministic and comparable to the baseline, which used the eval-mode
# model returned by from_pretrained.
model.eval()
get_results(input_file, output_file, model, tokenizer)