In [1]:
import pandas as pd 
import numpy as np 
import seaborn as sns 
import matplotlib.pyplot as plt 
import os 
import sys 
import warnings 
import random
from pprint import pprint as pp
from dotenv import load_dotenv
import os
from huggingface_hub import whoami, HfFolder

import gc

import torch
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from torch.utils.data.dataloader import DataLoader

from transformers import BitsAndBytesConfig
from transformers import DataCollatorForLanguageModeling
from transformers import AutoTokenizer, AutoModelForCausalLM 
from transformers import set_seed, Seq2SeqTrainer, LlamaTokenizer

from datasets import Dataset, DatasetDict

from peft import LoraConfig, get_peft_model


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Prompt template for the hate-speech classification task. The single `{}`
# placeholder is filled with the message text by `format_prompt`.
base_prompt = """You are a social media content moderator.
INSTRUCTION: The following is a social media message that needs to be classified with the label HATEFUL or NOT HATEFUL.
MESSAGE: {}
OUTPUT AND FORMAT: your output should be just the label."""


def log_hf():
    """Store the Hugging Face access token and print the logged-in user name.

    The token is read from the ``HF_ACCESS_TOKEN`` environment variable after
    loading ``env_vars.env``. It is persisted with ``HfFolder.save_token`` and
    the ``name`` field returned by ``whoami()`` is printed.

    Returns:
        None.

    Raises:
        RuntimeError: if ``HF_ACCESS_TOKEN`` is missing or empty, so the
            failure is reported here rather than inside the hub client.
    """
    load_dotenv("env_vars.env")
    hf_token = os.environ.get("HF_ACCESS_TOKEN")
    if not hf_token:
        raise RuntimeError(
            "HF_ACCESS_TOKEN is not set; add it to env_vars.env or the environment."
        )

    HfFolder.save_token(hf_token)
    print(whoami()["name"])

def format_prompt(text, base_prompt=base_prompt):
    """Insert ``text`` into the ``{}`` placeholder of ``base_prompt``.

    Args:
        text: Message to classify; it is substituted via ``str.format``.
        base_prompt: Template containing one positional placeholder.
            Defaults to the module-level ``base_prompt``.

    Returns:
        The template with the message filled in.
    """
    return base_prompt.format(text)

def translate_class_to_label(class_):
    """Map a dataset class name to the label string used in the prompt.

    Both ``"explicit_hate"`` and ``"implicit_hate"`` collapse to ``"HATEFUL"``;
    ``"not_hate"`` becomes ``"NOT HATEFUL"``.

    Raises:
        KeyError: if ``class_`` is not one of the three known class names.
    """
    class_to_label = dict(
        not_hate="NOT HATEFUL",
        explicit_hate="HATEFUL",
        implicit_hate="HATEFUL",
    )
    return class_to_label[class_]

def format_message(formatted_prompt, label=True):
    """Build the chat-style message list for one example.

    The conversation always starts with a fixed system turn followed by the
    user turn holding ``formatted_prompt``. When ``label`` is truthy an
    assistant turn whose content is ``label`` is appended; a falsy ``label``
    (e.g. ``False``) yields a prompt-only conversation.

    NOTE(review): with the default ``label=True`` the assistant content is the
    boolean ``True`` itself, not a string. Callers are presumably expected to
    pass either the label text or ``False``; confirm.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dictionaries.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": formatted_prompt},
    ]
    if label:
        # Training-style example: the expected answer is the assistant turn.
        conversation.append({"role": "assistant", "content": label})
    return conversation

# def keytoken_weighted_loss(inputs, logits):
#     # Shift so that tokens < n predict n
#     shift_labels = inputs[..., 1:].contiguous()
#     shift_logits = logits[..., :-1, :].contiguous()
#     # Calculate per-token loss
#     loss_fct = CrossEntropyLoss(reduce=False)
#     loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
#     # Resize and average loss per sample
#     loss_per_sample = loss.view(shift_logits.size(0), shift_logits.size(1)).mean(axis=1)

#     return loss_per_sample


# TODO: confirm whether the chat template needs to be applied here.

# padding (`bool`, defaults to `False`):
#     Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`.
# truncation (`bool`, defaults to `False`):
#     Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
# max_length (`int`, *optional*):
#     Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
#     not specified, the tokenizer's `max_length` attribute will be used as a default.


def preprocess_and_tokenize(formatted_prompt, label=True, add_generation_prompt=False, context_length=512, output_messages_list=False):
    """Turn a prompt (and optional label) into chat-template token tensors.

    The messages are built with ``format_message`` and encoded with the
    module-level ``tokenizer`` through ``apply_chat_template``. The encoding
    is padded to ``max_length``, truncated, and returned as a dictionary of
    PyTorch tensors.

    Args:
        formatted_prompt: Prompt text placed in the user turn.
        label: Forwarded to ``format_message``; a falsy value omits the
            assistant turn.
        add_generation_prompt: Forwarded to ``apply_chat_template``.
        context_length: Used as ``max_length`` for padding and truncation.
        output_messages_list: When true, the message list is returned as well.

    Returns:
        The tokenizer output, or ``(tokenizer output, messages)`` when
        ``output_messages_list`` is true.
    """
    chat_messages = format_message(formatted_prompt, label)

    template_kwargs = dict(
        tokenize=True,
        add_generation_prompt=add_generation_prompt,
        padding="max_length",
        truncation=True,
        max_length=context_length,
        return_dict=True,
        return_tensors="pt",
    )
    encoded = tokenizer.apply_chat_template(chat_messages, **template_kwargs)

    return (encoded, chat_messages) if output_messages_list else encoded


In [5]:
# Silence all Python warnings for the rest of the notebook.
warnings.filterwarnings("ignore")
# log_hf()
# Load environment variables (e.g. access tokens) from the local env file.
load_dotenv("env_vars.env")

# Seed every random number generator the notebook touches with the same value.
SEED = 42
set_seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)


In [6]:
########################################################## DATA WORK
# Progress banner (emitted as two lines, exactly as before).
print("_________________________________", "Preapring the Data", sep="\n")

# Full dataset; later cells read the `target`, `clean_post`, `class`,
# `split` and `time` columns from it.
df = pd.read_csv("df_from_exp_to_imp.csv")


_________________________________
Preapring the Data


In [35]:
# Distinct values of the `target` column, most frequent first.
list_targets = df["target"].value_counts().index.tolist()

In [38]:
# Lower-case the targets and drop duplicates. Sorting gives a stable order:
# plain set iteration order for strings changes between kernel sessions, which
# made anything derived from this list (e.g. list_targets.txt) non-reproducible.
list_targets = sorted({target.lower() for target in list_targets})

In [40]:
# Number of entries currently held in `list_targets`.
len(list_targets)

346

In [28]:
# Spot check: is the exact string "white liberals" one of the targets?
"white liberals" in list_targets

True

In [None]:
# check Anti-LGBTQ organizations, Degenerates

In [None]:
# Manual mapping from a coarse target category (dictionary key) to the raw
# `target` strings assigned to that category.
# - A string may be listed under several categories (e.g. "mexicans and
#   mestizos" appears under "immigrants", "xenophobic" and "race"), so the
#   mapping is not one-to-one.
# - Entries keep mixed capitalisation and punctuation (e.g. "Immigrants.",
#   "Jews"). NOTE(review): confirm they match the dataset values exactly before
#   relying on them for exact-match lookups.
translation_dictionary = {"immigrants":     ["Immigrants.", 
                                            "immigrants", 
                                            "Mexican immigratns",
                                            "mexicans and mestizos",
                                            "foreigners",
                                            "muslim immigrants",
                                            "irish and polish immigrants",
                                            "muslim refugees",
                                            "latino immigrants",
                                            "illegal aliens",
                                            "aliens",
                                            "outsiders"],

                                "political":      ["Progressives", 
                                                "Democrats", 
                                                "Republicans", 
                                                "Liberals", 
                                                "Right wing conservatives", 
                                                "Conservatives", 
                                                "white liberals", 
                                                "white liberal women", 
                                                "Progressive Indians",
                                                "religious conservatives",
                                                "alt-right people",
                                                "white progressives",
                                                "muslims, liberals, protesters",
                                                "black activists",
                                                "anti-gun activist.",
                                                "conservatives",
                                                "white conservatives",
                                                "conservative males",
                                                "left-wing liberals",
                                                "antifa",
                                                "politician",
                                                "male conservatives",
                                                "democratics",
                                                "nationalists",
                                                "lefties",
                                                "capitalists",
                                                "trump supporters",
                                                "liberals",
                                                "republicans",
                                                "leftists"],

                                "religious":        ["Jews", 
                                                "jewish",
                                                "Muslim people", 
                                                "Christians", 
                                                "Jewish People", 
                                                "Hinduists", 
                                                "Atheists", 
                                                "White people, christians", 
                                                "muslims", 
                                                "Islam people.",
                                                "religious people.",
                                                "christian leaders",
                                                "non-christian whites",
                                                "religious people",
                                                "muslims, liberals, protesters",
                                                "christians, jew",
                                                "muslim children",
                                                "muslim immigrants",
                                                "islamic followers",
                                                "non-christians",
                                                "pakistani,iranian,islamic,jewish",
                                                "white christians",
                                                "catholics",
                                                "atheists",
                                                "jews and arabs"],

                                "lbtq+":        ["homosexuals",
                                                "Gay folks",
                                                "transgeners",
                                                "lgbt people, specially transsexuals.",
                                                "gays",
                                                "lgbqti",
                                                "transexual",
                                                "gay people"],
                                                
                                "xenophobic":       ["muslim and black folks",
                                                "black people",
                                                "immigrants and other minority groups.",
                                                "mexicans and mestizos",
                                                "pakistanis",
                                                "minorities, mainly black people.",
                                                "italians",
                                                "muslim children",
                                                "foreigners",
                                                "irish and polish immigrants",
                                                "black and asian people.",
                                                "indian folks",
                                                "islamic followers",
                                                "pakistani,iranian,islamic,jewish",
                                                "hispanics",
                                                "indian women",
                                                "jews and arabs"
                                                ],
                                        
                                "other": ["not specified",
                                                "poor whites",
                                                "robert",
                                                "neo nazis and white supremacist",
                                                "white america",
                                                "rich people",
                                                "furries",
                                                "racists",
                                                "impossible to determine without more information or context"],

                                "race":         ["white people",
                                                "middle east people",
                                                "black activists",
                                                "mexicans and mestizos",
                                                "pakistanis",
                                                "minorities, mainly black people.",
                                                "color people",
                                                "black and asian people.",
                                                "indian folks",
                                                "pakistani,iranian,islamic,jewish",
                                                "hispanics",
                                                "indian women",
                                                "mixed race"
                                                ],

                                "misogyny": ["womens",
                                                "fat white women",
                                                "indian women",

                                    ]
                            }

In [62]:
# Category names used as keys of the target-to-category mapping.
translation_dictionary.keys()

dict_keys(['immigrants', 'political', 'religious', 'lbtq+', 'xenophobic', 'other', 'race', 'misogyny'])

In [64]:
# Flat list of every target string already assigned to a category. A string
# listed under several categories appears once per category.
list_done = [entry for entries in translation_dictionary.values() for entry in entries]

# `list_targets` holds lower-cased strings while the dictionary entries are
# mixed-case ("Jews", "Democrats", ...), so compare case-insensitively;
# otherwise capitalised entries are reported as still pending.
done_lookup = {entry.lower() for entry in list_done}

# One pending target per line, each line terminated by a newline.
str_ = "".join(
    f"{target}\n" for target in list_targets if target.lower() not in done_lookup
)

with open("list_targets.txt", "w", encoding="utf-8") as f:
    f.write(str_)

In [65]:
# Total number of target strings listed across all categories; a string that
# is listed under several categories is counted once per category.
len(list_done)

117

In [23]:
# Boolean mask over the per-target counts: True where a target string occurs
# more than 100 times.
df["target"].value_counts().values > 100

array([ True,  True,  True,  True,  True,  True,  True, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False,

In [25]:
# Target strings that occur more than 100 times; the counts are computed once
# and reused for both the mask and the index.
target_counts = df["target"].value_counts()
target_counts.index[target_counts.values > 100]

Index(['Minorities', 'Immigrants', 'white people', 'Muslims', 'Jews', 'Blacks',
       'non-white people'],
      dtype='object', name='target')

In [7]:
# Display the loaded DataFrame.
df

Unnamed: 0,clean_post,post,class,implicit_class,extra_implicit_class,target,implied_statement,split,time
0,university of washington ( seattle wa ) #proj...,university of washington ( seattle wa ) #proj...,not_hate,,,,,validation,2
1,no excuse not to - neoliberal media / establis...,: no excuse not to - neoliberal media / establ...,not_hate,,,,,train,2
2,human being is a white termanimals ... are wha...,human being is a white termanimals ... are wha...,explicit_hate,,,,,validation,1
3,youre aware that australians are white right ...,youre aware that australians are white right ...,not_hate,,,,,train,2
4,and those who are too stupid to realize or car...,and those who are too stupid to realize or car...,not_hate,,,,,train,1
...,...,...,...,...,...,...,...,...,...
17379,blog : 3 reasons hawaiis law forcing pregnancy...,blog : 3 reasons hawaiis law forcing pregnancy...,not_hate,,,,,train,2
17380,i see nothing wrong with pride in ones white r...,i see nothing wrong with pride in ones white r...,not_hate,,,,,train,2
17381,american conservatism managed to kill both bod...,american conservatism managed to kill both bod...,not_hate,,,,,validation,1
17382,last time i looked 100 percent of homicide ar...,last time i looked 100 percent of homicide ar...,not_hate,,,,,train,1


In [None]:

# ### Attaching the prompt to the clean post

df["formatted_prompt"] = df["clean_post"].apply(format_prompt)
df["label"] = df["class"].apply(translate_class_to_label)

# ### Turning the Df into a DatasetDict
#
# One Dataset per (split, time) pair. The DatasetDict keys come straight from
# the values of the `split` column, so this works for any number of splits and
# does not need to read the split name back out of each (possibly empty)
# Dataset.

split_names = list(df["split"].unique())

t_1 = []
t_2 = []

for split in split_names:

    in_split = df["split"] == split

    # Filtering keeps the original row index, which `from_pandas` stores as the
    # `__index_level_0__` column (removed later together with the raw columns).
    t_1.append(Dataset.from_pandas(df[in_split & (df["time"] == 1)]))
    t_2.append(Dataset.from_pandas(df[in_split & (df["time"] == 2)]))

hf_time_1 = DatasetDict(dict(zip(split_names, t_1)))
hf_time_2 = DatasetDict(dict(zip(split_names, t_2)))


In [None]:
########################################################## TOKENIZER WORK

# Identifier handed to `from_pretrained`; the same `model_id` is reused later
# in the notebook to load the model and to name the saved artefacts.
model_id = "Models/TinyLlama"

# `tokenizer` is read as a module-level global by `preprocess_and_tokenize`.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token


In [None]:

# Tokenize both time periods. Each row's prompt and label are passed
# positionally to `preprocess_and_tokenize`, one example at a time.
tokenize_columns = ["formatted_prompt", "label"]
hf_time_1 = hf_time_1.map(preprocess_and_tokenize, input_columns=tokenize_columns, batched=False)
hf_time_2 = hf_time_2.map(preprocess_and_tokenize, input_columns=tokenize_columns, batched=False)

for tokenized_dataset_dict in (hf_time_1, hf_time_2):
    tokenized_dataset_dict.set_format("torch")

cols_to_remove = ["clean_post", "post", "class", "implicit_class", "extra_implicit_class", "target", "implied_statement", "split", "time", "formatted_prompt", "label", "__index_level_0__"]

# Drop the raw columns everywhere except the test split, which keeps
# `formatted_prompt` for the generation check later in the notebook.
for split in hf_time_1:
    if split == "test":
        continue
    hf_time_1[split] = hf_time_1[split].remove_columns(cols_to_remove)
    hf_time_2[split] = hf_time_2[split].remove_columns(cols_to_remove)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

batch_size = 8


def _build_loader(split_dataset):
    """Wrap one split in a DataLoader using the shared collator and batch size."""
    return DataLoader(split_dataset, collate_fn=data_collator, batch_size=batch_size)


# NOTE(review): the test splits still contain their raw text columns; confirm
# the collator can batch them before iterating the two test loaders.
hf_time_1_train_loader = _build_loader(hf_time_1["train"])
hf_time_1_validation_loader = _build_loader(hf_time_1["validation"])
hf_time_1_test_loader = _build_loader(hf_time_1["test"])

hf_time_2_train_loader = _build_loader(hf_time_2["train"])
hf_time_2_validation_loader = _build_loader(hf_time_2["validation"])
hf_time_2_test_loader = _build_loader(hf_time_2["test"])

# ### So far, created the prompt, did the messages with the prompt and answer in place. Applied to chat template and tokenized 


In [None]:
########################3#################### MODEL WORK

print("_________________________________")
print("Loading the model and model config")

# 4-bit quantization settings handed to the model loader below.
bnb_config = BitsAndBytesConfig(  
                                load_in_4bit= True,
                                bnb_4bit_quant_type= "nf4",
                                bnb_4bit_compute_dtype= torch.bfloat16,
                                bnb_4bit_use_double_quant= True,
                            )

# Base causal LM loaded from `model_id` with the quantization config above;
# device placement is delegated to `device_map="auto"`.
model = AutoModelForCausalLM.from_pretrained(model_id,
                                            torch_dtype=torch.bfloat16,
                                            device_map="auto",
                                            quantization_config=bnb_config
                                            )

# to deal with the fact that we dont make the first token prediction??


# Parameter count of the loaded base model, reported before LoRA is attached.
model_size_before = sum(t.numel() for t in model.parameters())
print("Model Size before LoRA", model_size_before)
print(model)
print()

# LoRA adapter settings: rank 8 adapters on the four attention projections.
config = LoraConfig(
    r=8,
    lora_alpha=8,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.1,
    bias="none",
)

# From here on `model` is the PEFT-wrapped model.
model = get_peft_model(model, config)
print("Model After LoRA")
model.print_trainable_parameters()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
# NOTE(review): the model was already placed via `device_map="auto"` and is
# quantized; confirm that an explicit `.to(device)` is supported/needed here.
model.to(device)


# Loss object and optimizer used by the training cell. Only parameters that
# still require gradients are handed to the optimizer.
loss_fn = CrossEntropyLoss()
lr = 1e-5
optimizer = AdamW((param for param in model.parameters() if param.requires_grad), lr=lr)


In [None]:
################################################## TRAINING AND TESTING
n_epochs = 2

# Debug cap: how many batches are processed per epoch, for training and for
# validation. The loop stops once the cap is reached instead of walking (and
# collating) the rest of the loader only to skip it.
max_batches = 1

print("_________________________________")
print("Training the model")
print()

for epoch in range(n_epochs):

    torch.cuda.empty_cache()
    gc.collect()
    model.train()

    print("Epoch: ", epoch)
    losses = []

    for i, batch in enumerate(hf_time_1_train_loader):
        if i >= max_batches:
            break

        torch.cuda.empty_cache()
        gc.collect()

        print("\tBatch: ", i)
        batch.to(device)

        # Drop only the extra singleton axis at position 1 (assumes tensors are
        # (batch, 1, seq) after collation -- TODO confirm). A bare squeeze
        # would also remove the batch axis when a batch holds one example.
        batch = {k: torch.squeeze(v, dim=1) for k, v in batch.items()}

        # The batch carries `labels`, so the model returns the causal-LM loss
        # itself (labels shifted against the logits). Feeding the raw
        # (batch, seq, vocab) logits to CrossEntropyLoss does not match the
        # layout that loss expects and skips the shift.
        output = model(**batch)
        loss = output.loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.detach().item())

        print(batch.keys())
        print(loss.detach().item())
        print(output.logits.shape)

    # Average over the batches that were actually processed.
    epoch_loss = sum(losses) / max(len(losses), 1)
    print(f"Epoch {epoch} Loss: {epoch_loss}")

    model.eval()
    with torch.no_grad():

        torch.cuda.empty_cache()
        gc.collect()

        val_losses = []

        for i, batch in enumerate(hf_time_1_validation_loader):
            if i >= max_batches:
                break
            batch.to(device)
            batch = {k: torch.squeeze(v, dim=1) for k, v in batch.items()}

            output = model(**batch)
            val_loss = output.loss

            val_losses.append(val_loss.detach().item())

        val_loss_epoch = sum(val_losses) / max(len(val_losses), 1)
        print(f"Epoch {epoch} Validation Loss: {val_loss_epoch}")

print()
print("_________________________________")
print("Testing the model")

# Upper bound on generated tokens; the expected answer is a short label.
max_new_tokens = 8

for i, test_batch in enumerate(hf_time_1["test"]):

    # Smoke test: only the first test example is generated.
    if i > 0:
        break

    text = test_batch["formatted_prompt"]
    tokenized_chat_template, messages_list = preprocess_and_tokenize(text, label=False, add_generation_prompt=True, output_messages_list=True)

    # `preprocess_and_tokenize` pads every prompt to the full context length.
    # Keep only the positions the attention mask marks as real tokens so the
    # model continues from the end of the prompt rather than from padding.
    # Assumes a single (1, seq) example -- TODO confirm.
    generation_inputs = tokenized_chat_template.to(device)
    real_positions = generation_inputs["attention_mask"][0].bool()
    generation_inputs = {k: v[:, real_positions] for k, v in generation_inputs.items()}

    output = model.generate(**generation_inputs,
                            max_new_tokens=max_new_tokens,
                            pad_token_id=tokenizer.pad_token_id)
    # `output[0]` holds the prompt tokens followed by the generated tokens.
    pred = tokenizer.decode(output[0], skip_special_tokens=True)

    print(text)
    print(tokenized_chat_template)
    print(output)
    print(pred)


In [None]:
print("CHECKING GENERATION")
# Inspect the artefacts left behind by the generation check, one per line:
# the chat messages, the tokenized prompt, the generated ids, their type and
# their shape.
for debug_value in (messages_list, tokenized_chat_template, output, type(output), output.shape):
    print(debug_value)

print("_________________________________")
print("Saving the model and Tokenizer")
# The last path component of `model_id` names the output directory.
model_name = model_id.split("/")[-1]
save_dir = f"alberto-lorente/{model_name}_test"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

print("RUN SUCCESSFULLY")