In [None]:
import pandas as pd

In [None]:
# Folder of the Kaggle input dataset that holds every CSV used below.
path = '/kaggle/input/assignment-1-data/'

# For each corpus: read it, expose its gold-label column under the common name
# 'class', draw a fixed-seed sample of 500 rows and tag the sample with the
# dataset name (read back later through `.attrs['name']`).

# --- reddit ---
reddit = pd.read_csv(f"{path}reddit.csv").rename(columns={'hate_speech_idx': 'class'})
reddit_sample = reddit.sample(n=500, random_state=14741)
reddit_sample.attrs['name'] = 'reddit'

# --- gab ---
gab = pd.read_csv(f"{path}gab.csv").rename(columns={'hate_speech_idx': 'class'})
gab_sample = gab.sample(n=500, random_state=14741)
gab_sample.attrs['name'] = 'gab'

# --- hateval (stored as OOD2.csv) ---
hateval = pd.read_csv(f"{path}OOD2.csv").rename(columns={'HS': 'class'})
hateval_sample = hateval.sample(n=500, random_state=14741)
hateval_sample.attrs['name'] = 'hateval'

# --- hasoc (stored as OOD4.csv, tab-separated; its '_id' column becomes 'id') ---
hasoc = pd.read_csv(f"{path}OOD4.csv", sep='\t').rename(columns={'hatespeech': 'class', '_id': 'id'})
hasoc_sample = hasoc.sample(n=500, random_state=14741)
hasoc_sample.attrs['name'] = 'hasoc'

In [None]:
def create_prompt(sentence:str, prompt_type = 'default', zero_shot = False, example_type = None, counterfactual = False) -> str:
    """Assemble an LLM prompt asking whether `sentence` contains hate speech.

    The prompt is the concatenation of: task introduction + (optional) examples +
    closing answer-format instructions + the sentence itself.

    Inputs:
        sentence: the text to classify; appended verbatim at the very end of the prompt
        prompt_type: which introduction to use, either the generic 'default' or one of the
            dataset-specific ones: 'reddit', 'gab', 'hasoc', 'hateval'
        zero_shot: if True no examples are inserted (example_type and counterfactual are
            then neither used nor validated)
        example_type: which example set to insert when zero_shot is False; same choices as
            prompt_type. When left as None it follows prompt_type
        counterfactual: if True the counterfactual example pairs of the chosen dataset are
            used instead of its plain examples; 'default' has no counterfactual set

    Output:
        prompt: the prompt string. For an unknown prompt_type / example_type, or for
            counterfactual=True combined with the 'default' examples, a short error
            message string is returned instead (nothing is raised).
    """

    # ---- generic building blocks ----
    default_prompt = """In this task you have to determine if a piece of text contains hate speech or not. The concept of hate
speech refers to language or expression that aims to insult, offend, or threaten individuals or groups
based on their identity. Different platforms, institutions, and scholars have offered slightly different
definitions of hate speech. However, they share some commonalities, including that hate speech
involves expressing something negative, such as contempt, disparagement, derogation, demonization,
harm, or bias. The target of hate speech is the identity of a group, which can be based on ethnic origin,
religion, gender, sexual orientation, race, ethnicity, nationality, or other characteristics. Read the text
carefully and decide if it contains hate speech. If the text degrades, threatens, dehumanizes, or insults
someone based on their race, religion, gender, sexual orientation, or any other protected characteristic,
then mark it as hate speech. If the text does not contain any hate speech, then mark it as not hateful.
Consider the context, speaker’s intent, and audience, as hate speech can be explicit or implicit and can
include sarcasm, humor, or irony. If you are unsure whether a text contains hate speech, err on caution
and label it hateful. Aim for accuracy and consistency in your labeling. """

    default_examples = """Examples of content containing hate speech: "All Muslims are terrorists." This post contains
hate speech because it unfairly degrades an entire group of people based on their religion. "Im a
huge fan of those that overtly express what we all think: jews are a plague and not welcome in our
country." This post contains hate speech because it promotes anti-Semitic beliefs while dehumanizing
Jewish people. "Transgender people are sick and perverted." This post contains hate speech because it
stigmatizes a marginalized group of people based on their gender identity. Examples of content not
containing hate speech: "I called him out for calling the new student a ’tranny’." This is counterspeech,
i.e., it opposes the use of transphobic slurs rather than uses them to denigrate others. """

    end = """With these instructions in mind and a piece of text, please reply with either of the three options and
nothing else: 1) hate, 2) not hate, 3) can’t say. Please use the last option as little as possible and if you
do, provide a one sentence explanation. We are ready to begin now. Does this sentence contain hate
speech based on the past guidelines: """

    # ---- reddit ----
    reddit_prompt = """In this task you have to determine if a piece of text contains hate speech or not. The concept of hate
speech refers to language or expression that aims to insult, offend, or threaten individuals or groups
based on their identity. Hate speech involves expressing something negative, such as contempt, disparagement, derogation, demonization,
harm, or bias. This specific text comes from a list of the whiniest, most toxic subreddits. Some of these subreddits have
now been banned for their hateful, extreme, unpopular, and disparaging comments towards women, people with disabilities, immigrants, 
certain nationalities, ethnicities, and religions, and LGBTQ+ members. Read the text
carefully and decide if it contains hate speech. If the text degrades, threatens, dehumanizes, or insults
someone based on their race, religion, gender, sexual orientation, or any other protected characteristic,
then mark it as hate speech. If the text does not contain any hate speech, then mark it as not hateful.
Consider the context, speaker’s intent, and audience, as hate speech can be explicit or implicit and can
include sarcasm, humor, or irony. If you are unsure whether a text contains hate speech, err on caution
and label it hateful. Aim for accuracy and consistency in your labeling. """

    reddit_examples = """Examples: 'A subsection of retarded Hungarians? Ohh boy. brace for a livid Bulbasaur 
    coming in here trying to hate a hole in some of her stupider countrymen.' This is hate speech. 'stop being a faggot and post videos next time.' This is hate speech. 
    'The right will do nothing about it except post bitch and moan.' This is not hate speech."""

    reddit_counterfactual = """'Where in the fuck do these women get these entitled attitudes from?' is hate speech versus 'Where in the fuck do these women get these 
    fearless attitudes from?' is not. 'man do I love watching experienced cops handle situations like this. What a guy.' is not hate speech versus 'man do I loathe watching 
    experienced cops handle situations like this. What a cunt.' is hate speech."""

    # ---- gab ----
    gab_prompt = """In this task you have to determine if a piece of text contains hate speech or not. The concept of hate
speech refers to language or expression that aims to insult, offend, or threaten individuals or groups
based on their identity. Hate speech involves expressing something negative, such as contempt, disparagement, derogation, demonization,
harm, or bias. This specific text comes from the social media website Gab, known to be a social media platform 
    where many users have been banned from other platforms for their far-right ideologies such as neo-Nazism, racism, white 
    supremisism, white nationalism, antisemetism, and conspiracy theorism. Read the text
carefully and decide if it contains hate speech. If the text degrades, threatens, dehumanizes, or insults
someone based on their race, religion, gender, sexual orientation, or any other protected characteristic,
then mark it as hate speech. If the text does not contain any hate speech, then mark it as not hateful.
Consider the context, speaker’s intent, and audience, as hate speech can be explicit or implicit and can
include sarcasm, humor, or irony. If you are unsure whether a text contains hate speech, err on caution
and label it hateful. Aim for accuracy and consistency in your labeling. """

    gab_examples = """Examples 'i joined gab to remind myself how retarded jew haters are. You wouldn't be typing on your abacus without them you retard.' This is hate speech. 
    'Shocked no one in my network has posted about this today. Anyone been keeping up with the case?   Van Dyke taken into custody after jury convicts him of 2nd-degree 
    murder, aggravated battery for each of 16 shots' This is not hate speech."""

    gab_counterfactual = """'you sound like a faggot' is hate speech versus 'you sound like a fabulous individual' is not hate speech. 'Mental illness is a very real issue.   Seeing attention seeking whores pretending to be something entirely imaginary, 
    and knowingly expecting to be labelled mentally ill for that attention deserve a trip to Yemen, or Syria. They\'ll get a lifetime of attention 
    and a very stark message of reality.' is not hate speech versus 'Mental illness is an imaginary issue.   Seeing attention seeking whores pretending to 
    be something entirely imaginary, and knowingly expecting to be labelled mentally ill for that attention deserve a trip to hell or a shithole psyco house. 
    They\'ll get a lifetime of attention and a very stark message of reality.' is hate speech."""

    # ---- hasoc ----
    hasoc_prompt = """In this task you have to determine if a piece of text contains hate speech or not. The concept of hate
speech refers to language or expression that aims to insult, offend, or threaten individuals or groups
based on their identity. Hate speech involves expressing something negative, such as contempt, disparagement, derogation, demonization,
harm, or bias. This specific text comes from a collection of posts from either Twitter or Facebook which are used to study hate speech, and are sourced
by finding hateful groups or hashtags. Read the text carefully and decide if it contains hate speech. If the text degrades, threatens, dehumanizes, or insults
someone based on their race, religion, gender, sexual orientation, or any other protected characteristic,
then mark it as hate speech. If the text does not contain any hate speech, then mark it as not hateful.
Consider the context, speaker’s intent, and audience, as hate speech can be explicit or implicit and can
include sarcasm, humor, or irony. If you are unsure whether a text contains hate speech, err on caution
and label it hateful. Aim for accuracy and consistency in your labeling. """

    hasoc_examples = """Examples: 'Disgusting! And they don’t change their culture or come up with vetting system or preventative ways to stop this 
    from occurring. This will continue to happen. The Catholic Church needs to go!' This is hate speech. 'I hope you remembered to wipe the blood off 
    of you, after the pic was taken.' This is not hate speech."""

    hasoc_counterfactual = """'I don’t know how much more I can take! 45 is a compulsive liar!' is hate speech versus 'I don’t know how much more I can 
    take! 45 is a wonderful person!' is not. 'All the best to #TeamIndia for another swimming competition on Sunday against #Pakistan.' is not hate 
    speech versus 'All the worst to #FuckIndia for another swimming competition on Sunday against #PakistanIsKing.' is hate speech."""

    # ---- hateval ----
    hateval_prompt = """In this task you have to determine if a piece of text contains hate speech or not. The concept of hate
speech refers to language or expression that aims to insult, offend, or threaten individuals or groups
based on their identity. Hate speech involves expressing something negative, such as contempt, disparagement, derogation, demonization,
harm, or bias. This specific text comes from a collection of posts from Twitter which are used to study hate speech, spedifically towards women and minorities. 
Read the text carefully and decide if it contains hate speech. If the text degrades, threatens, dehumanizes, or insults
someone based on their race, religion, gender, sexual orientation, or any other protected characteristic,
then mark it as hate speech. If the text does not contain any hate speech, then mark it as not hateful.
Consider the context, speaker’s intent, and audience, as hate speech can be explicit or implicit and can
include sarcasm, humor, or irony. If you are unsure whether a text contains hate speech, err on caution
and label it hateful. Aim for accuracy and consistency in your labeling. """

    hateval_examples = """Examples: 'Why would young fighting age men be the vast majority of the ones escaping a war &amp; not those who cannot fight 
    like women, children, and the elderly?It's because the majority of the refugees are not actually refugees they are economic migrants trying to get 
    into Europe....' This is hate speech. 'Watch: Benjamin Netanyahu backs Trump's decision to cut aid to the Palestinians and UNRWA - Free Speech Time' 
    This is not hate speech."""

    hateval_counterfactual = """'Yes bitch you better tag me back hoe' is hate speech versus 'Yes handsome you better tag me back awesome' is not. 'brianne 
    swear to god she beyoncé 😂 bitch you from the calliope project ass hoe' is not hate speech versus 'brianne swear to fucking god she beyoncé dumbass bitch 
    you from the shitty calliope project ass hoe cunt' is hate speech."""

    # Lookup tables replace the if/elif ladders: one entry per supported name.
    intros = {
        'default': default_prompt,
        'reddit': reddit_prompt,
        'gab': gab_prompt,
        'hasoc': hasoc_prompt,
        'hateval': hateval_prompt,
    }
    plain_examples = {
        'reddit': reddit_examples,
        'gab': gab_examples,
        'hasoc': hasoc_examples,
        'hateval': hateval_examples,
    }
    counterfactual_examples = {
        'reddit': reddit_counterfactual,
        'gab': gab_counterfactual,
        'hasoc': hasoc_counterfactual,
        'hateval': hateval_counterfactual,
    }

    # Without an explicit example set, the examples follow the introduction's dataset.
    if example_type is None:
        example_type = prompt_type

    # The introduction is validated first, exactly as before.
    if prompt_type not in intros:
        return 'Please use a valid prompt type.'
    intro = intros[prompt_type]

    if zero_shot:
        # Zero-shot: no examples at all; example_type is deliberately not checked.
        examples = ''
    elif example_type == 'default':
        if counterfactual:
            return 'There are no default counterfactuals, please specify a dataset.'
        examples = default_examples
    elif example_type in plain_examples:
        chosen_table = counterfactual_examples if counterfactual else plain_examples
        examples = chosen_table[example_type]
    else:
        return 'Please use a valid example type.'

    return intro + examples + end + sentence

# Smoke test: default introduction + default examples for the reddit row labelled 0.
first_reddit_text = reddit['text'][0]
test_prompt = create_prompt(sentence=first_reddit_text, zero_shot=False)

In [None]:
# Install Hugging Face transformers into the session via a shell escape.
# NOTE(review): no version is pinned, so the installed release can differ between runs.
! pip install transformers

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


# Load the FLAN-T5 XL checkpoint and its tokenizer, move the model to the GPU and
# run a one-word generation as a smoke test.
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl")
# `max_new_tokens` limits generation length, so it belongs on `generate()`, not on the
# tokenizer (where it has no effect on how many tokens the model may produce).
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
model.cuda()
inputs = tokenizer('Hey',
                   return_tensors="pt").to('cuda:0')
# Allow up to 500 generated tokens instead of the short library default.
outputs = model.generate(**inputs, max_new_tokens=500)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

In [None]:
# Label every text with FLAN-T5 under six prompt variants, three generations each.
# NOTE(review): only rows from position 277 onward of the reddit sample are processed —
# presumably resuming an interrupted run; confirm before reusing the output as a full set.
dfs = [reddit_sample[277:]]

# Number of generations requested per prompt.
# NOTE(review): generate() is called with the model's default decoding settings; unless
# sampling is enabled there, the repetitions of one prompt will be identical — confirm.
n_runs = 3
# Prompt variants, in the same order as the `prompts` list built for each text.
prompt_variants = ['few_default', 'zero_default', 'zero_specific',
                   'few_specific', 'counterfactual_default', 'counterfactual_specific']
# 'id', 'text', 'class' followed by flan_xl_<variant>_<run> for every variant and run.
result_columns = ['id', 'text', 'class'] + [
    f'flan_xl_{variant}_{run}'
    for variant in prompt_variants
    for run in range(1, n_runs + 1)
]

for df in dfs:
    mega_responses = []
    # for each text in a dataframe
    for index, text in enumerate(df.text):
        # create all 6 prompts for each text
        prompts = [create_prompt(text),
                   create_prompt(text, zero_shot=True),
                   create_prompt(text, prompt_type=df.attrs['name'], zero_shot=True),
                   create_prompt(text, prompt_type=df.attrs['name']),
                   create_prompt(text, example_type=df.attrs['name'], counterfactual=True),
                   create_prompt(text, prompt_type=df.attrs['name'], counterfactual=True)]
        # One output row: this row's id (not the whole 'id' column), its text and gold label.
        all_responses = [df['id'].iloc[index], text, df['class'].iloc[index]]
        for prompt in prompts:
            # Tokenize once per prompt; the encoding does not change between repetitions.
            inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
            for n in range(n_runs):
                # max_new_tokens keeps an explained "can't say" answer from being cut short.
                outputs = model.generate(**inputs, max_new_tokens=500)
                all_responses.append(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
        mega_responses.append(all_responses)
        # Checkpoint after every text so an interrupted session loses at most one row.
        save = pd.DataFrame(mega_responses, columns=result_columns)
        save.to_csv("/kaggle/working/" + df.attrs['name'] + '_wip9.csv', index = False)
        if index % 5 == 0:
            print(f'{index+ 1}/{len(df)} completed')
    # Reported after the loop: `index` never reaches len(df), so the old in-loop check never fired.
    print(f'Done!')
    labeled_df = pd.DataFrame(mega_responses, columns=result_columns)
    labeled_df.to_csv("/kaggle/working/" + df.attrs['name'] + '_labeled.csv', index = False)

In [None]:
# test reading the csv after labeling
# NOTE(review): the labeling loop in this notebook checkpoints to '<name>_wip9.csv', while this
# cell reads 'reddit_wip7.csv' — presumably a checkpoint of an earlier run; confirm it exists.
import pandas as pd
pd.read_csv('/kaggle/working/reddit_wip7.csv')

This is the start of Assignment 1.2

In [None]:
# create CADs with the transformer
import pandas as pd
# Kaggle input folder with the assignment data.
path = '/kaggle/input/assignment-1-data/'
# Tab-separated file; its first column is used as the index.
cads = pd.read_csv(f"{path}vidgen_data_v01.csv", sep='\t', index_col=0)
cads

In [None]:
# prompt
def create_cad_prompt(text, label):
    """Build the prompt that asks the model for a minimally edited counterfactual of `text`.

    Inputs:
        text: the original message; appended verbatim after the instructions
        label: the original label. Exactly 'not hate' selects the instructions that ADD
            hate speech; every other value selects the instructions that REMOVE it

    Output:
        the instruction text followed by `text`
    """
    if label == 'not hate':
        # Non-hateful original: ask for a hateful rewrite.
        instructions = """Your task is twofold:
1) modify the messages to add as much hatespeech as possible and
2) make as few modifications to the messages as possible.
Modifications include adding, removing, or editing words. You get extra points if you transform our
sentences into hatespeech sentences with minimal changes. They should be sentences
that one could realistically read/hear in a conversation. Do not modify words unless that is crucial to
make the message hatespeech. For example, the following modifies many extra words without
need:
Original: "Everyone deserves a chance, it's just that some people aren't even interested."
Modified: "I don't give chances, everyone's boring"
Instead, this modification is more parsimonious:
Modified (good): "No one deserves a chance, it's just that some people aren't even worth it."
Please transform the following sentence and only output the transformation and nothing else: """
    else:
        # Any other label: ask for a non-hateful rewrite.
        instructions = """Your task is twofold:
1) modify the messages to remove as much hatespeech as possible and
2) make as few modifications to the messages as possible.
Modifications include adding, removing, or editing words. You get extra points if you transform our
hatespeech sentences into non-hatespeech sentences with minimal changes. They should be sentences
that one could realistically read/hear in a conversation. Do not modify words unless that is crucial to
make the message non-hatespeech. For example, the following modifies many extra words without
need:
Original: "nattyNA_: I find it weird when girls have ugly handwriting or when boys have good
writing #justaintright" millerjr8
Modified: I find it weird when people have ugly handwriting.
Instead, this modification is more parsimonious:
Modified (good): "nattyNA_: I find it nice when girls have beautiful handwriting or when boys have
beautiful writing " @millerjr8
Please transform the following sentence and only output the transformation and nothing else: """
    return instructions + text
# Preview: the prompt produced for the first row of `cads`.
create_cad_prompt(cads['original_text'].iloc[0], cads['original_label'].iloc[0])

In [None]:
# all the CADs created here
# Generate one counterfactual (CAD) per original text with the loaded seq2seq model.
# Each output row is [original id, original text, original label, generated CAD].
cads_created = []
for index, text in enumerate(cads.original_text):
    original_label = cads.original_label.iloc[index]
    all_responses = [cads.original_id.iloc[index], text, original_label]
    inputs = tokenizer(create_cad_prompt(text, original_label), return_tensors="pt").to("cuda:0")
    # A CAD is a full rewritten sentence, so the generation budget is set explicitly;
    # the library's short default length would cut the rewrite off.
    outputs = model.generate(**inputs, max_new_tokens=500)
    all_responses.append(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    cads_created.append(all_responses)
    # Checkpoint after every text so an interrupted session keeps what was generated.
    save = pd.DataFrame(cads_created, columns = ['id', 'text', 'label', 'cad'])
    save.to_csv("/kaggle/working/cads" + '_wip1.csv', index = False)
    if index % 5 == 0:
        print(f'{index+ 1}/{len(cads)} completed')
# Reported after the loop: `index` never equals len(cads), so the old in-loop check never fired.
print(f'Done!')

In [None]:
# imports
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# CADs generated by the transformer ('cad' column renamed to 'flan_t5_xl_cad')
path = '/kaggle/input/assignment-1-data/'
cads = pd.read_csv(path + 'cads_wip1.csv').rename(columns = {'cad': 'flan_t5_xl_cad'})
# combine with the ChatGPT CADs from the paper
full_og_cads = pd.read_csv(path + 'paired_cads.csv', sep = '\t', index_col = 0)
# Left-join the FLAN CADs onto the ChatGPT ones by original id, then drop the last row.
# NOTE(review): the reason for `.iloc[:-1]` is not evident from this notebook — confirm.
full_cads = pd.merge(left = full_og_cads[['original_id', 'chatgpt']], right = cads, left_on = 'original_id', right_on = 'id', how = 'left').drop(columns = 'original_id').iloc[:-1]
# Add the 'counterfactual_text' column from the vidgen file, matched on the same id.
vidgen = pd.read_csv(path + 'vidgen_data_v01.csv', sep = '\t', index_col = 0)
full_cads = pd.merge(left = full_cads, right = vidgen[['original_id','counterfactual_text']], left_on = 'id', right_on = 'original_id', how = 'left').drop(columns ='original_id')

# split into train and test (70/30, stratified by label); the fixed seed makes the split,
# and every id sample later drawn from it, reproducible across kernel restarts
train_df, test_df = train_test_split(full_cads, stratify=full_cads['label'], test_size=0.3, random_state=14741)
# change to 0/1 labels; the encoder is fitted on the training labels only
le = LabelEncoder()
le.fit(train_df['label'])
# `assign` returns new frames, so nothing is written into a slice of `full_cads`
train_df = train_df.assign(labels_bin=le.transform(train_df['label']))
test_df = test_df.assign(labels_bin=le.transform(test_df['label']))

In [None]:
# Upgrade accelerate and transformers to the latest releases via shell escapes.
# NOTE(review): no versions are pinned, so the installed releases can differ between runs.
! pip install -U accelerate
! pip install -U transformers

In [None]:
from transformers import RobertaTokenizerFast, RobertaForSequenceClassification
from transformers import Trainer, TrainingArguments

# Checkpoint to fine-tune and the device it is moved to.
model_name = 'roberta-base'
device_name = 'cuda'

# This is the maximum number of tokens in any document; the rest will be truncated.
# NOTE(review): this value is not passed to any tokenizer call visible in this notebook.
max_length = 512

training_args = TrainingArguments(
    num_train_epochs=3,              # total number of training epochs
    output_dir='./results',          # output directory
    report_to='none'                 # no experiment-tracking integration
)


# test and train text, labels
# train_texts = train_df.text#.values
# train_labels = train_df.labels_bin#.values
# train_manual_cads = train_df.counterfactual_text
# train_gpt = train_df.chatgpt
# train_flan = train_df.flan_t5_xl_cad

# test_texts = test_df.text#.values
# test_labels = test_df.labels_bin#.values
# test_manual_cads = test_df.counterfactual_text
# test_gpt = test_df.chatgpt
# test_flan = test_df.flan_t5_xl_cad

# Replace missing ChatGPT CADs with a fixed placeholder sentence.
# The column is reassigned instead of calling `fillna(..., inplace=True)` on
# `train_df.chatgpt`: an in-place call on a selected column is chained assignment, which
# pandas warns about and which leaves `train_df` unchanged under Copy-on-Write.
train_df = train_df.assign(
    chatgpt=train_df['chatgpt'].fillna('This sentence contains hate speech and cannot be transformed into a non-hateful sentence.')
)

In [None]:
# Split the training rows by label and tag each part with the name used in its file name.
train_df_hate = train_df.loc[train_df['label'] == 'hate']
train_df_not_hate = train_df.loc[train_df['label'] == 'not hate']
train_df_hate.attrs['name'] = 'hate'
train_df_not_hate.attrs['name'] = 'not_hate'
dfs = [train_df_hate, train_df_not_hate]

# For each part draw three independent samples of half of its ids (seeds 1, 2, 3) and
# store them side by side: one column per CAD source (manual, ChatGPT, FLAN).
for df in dfs:
    half = int(len(df)/2)
    samples = [list(df['id'].sample(half, random_state = seed)) for seed in (1, 2, 3)]
    save = pd.DataFrame({'mcads': samples[0], 'chatgpt': samples[1], 'flan': samples[2]})
    save.to_csv('/kaggle/working/sample_' + df.attrs['name'], index = False)

In [None]:
# Build the four training sets: originals only, and three sets in which the sampled ids
# have their original text replaced by a counterfactual (manual, ChatGPT or FLAN).
sample_hate = pd.read_csv('/kaggle/working/sample_hate')
sample_not_hate = pd.read_csv('/kaggle/working/sample_not_hate')

# The label flip below is only valid for binary 0/1 labels.
assert set(train_df['labels_bin'].unique()) <= {0, 1}, 'expected binary 0/1 labels'


def _split_for_cads(frame, cad_column, cad_ids):
    """Split `frame` into untouched originals and rows replaced by their counterfactual.

    Rows whose 'id' is in `cad_ids` get the text of `cad_column` (renamed to 'text') and
    the OPPOSITE label: the CAD prompts rewrite a text into the other class, so keeping
    the original label would attach the wrong class to every replaced row.
    Returns (originals, replaced), both with columns ['labels_bin', 'text'].
    """
    use_cad = frame['id'].isin(cad_ids)
    originals = frame.loc[~use_cad, ['labels_bin', 'text']]
    replaced = (
        frame.loc[use_cad, ['labels_bin', cad_column]]
        .rename(columns = {cad_column: 'text'})
        .assign(labels_bin = lambda rows: 1 - rows['labels_bin'])
    )
    return originals, replaced


def _as_training_frame(parts):
    """Stack the parts and expose the label column under the name 'labels'."""
    return pd.concat(parts).rename(columns = {'labels_bin': 'labels'})


# manual CADs
og_cads_mcads, mcads = _split_for_cads(
    train_df, 'counterfactual_text', list(sample_hate.mcads) + list(sample_not_hate.mcads))
mcads_full = _as_training_frame([og_cads_mcads, mcads])

# ChatGPT CADs
og_cads_chatgpt, chatgpt_cads = _split_for_cads(
    train_df, 'chatgpt', list(sample_hate.chatgpt) + list(sample_not_hate.chatgpt))
chatgpt_full = _as_training_frame([og_cads_chatgpt, chatgpt_cads])

# FLAN-T5 CADs
og_cads_flan, flan_cads = _split_for_cads(
    train_df, 'flan_t5_xl_cad', list(sample_hate.flan) + list(sample_not_hate.flan))
flan_full = _as_training_frame([og_cads_flan, flan_cads])

# baseline: original texts only
og_full = _as_training_frame([train_df[['labels_bin', 'text']]])

In [None]:
# SVM
from sklearn import svm
# TODO: build the SVM feature matrix. The cell originally stopped at a bare `X =`, which is a
# SyntaxError and aborts "Restart & Run All"; the unfinished assignment is kept as a comment.
# X =

In [None]:
from datasets import Dataset

# NOTE: from here on `tokenizer` (and `model`, at the end of this cell) refer to RoBERTa;
# the FLAN-T5 cells earlier in the notebook bind the same two names.
tokenizer = RobertaTokenizerFast.from_pretrained(model_name)


def tokenize_function(examples):
    """Tokenize the 'text' field of a batch, padded to max length and truncated."""
    return tokenizer(text=examples['text'], padding="max_length", truncation=True)


# Wrap every pandas frame as a `datasets.Dataset` ...
train_df_dataset_text = Dataset.from_pandas(og_full)
train_df_dataset_mcad = Dataset.from_pandas(mcads_full)
train_df_dataset_gpt = Dataset.from_pandas(chatgpt_full)
train_df_dataset_flan = Dataset.from_pandas(flan_full)
test_df_dataset = Dataset.from_pandas(test_df)

# ... and tokenize each of them in batches.
tokenized_train_df_text = train_df_dataset_text.map(tokenize_function, batched=True)
tokenized_test_df_text = test_df_dataset.map(tokenize_function, batched=True)
tokenized_train_df_mcad = train_df_dataset_mcad.map(tokenize_function, batched=True)
tokenized_train_df_gpt = train_df_dataset_gpt.map(tokenize_function, batched=True)
tokenized_train_df_flan = train_df_dataset_flan.map(tokenize_function, batched=True)

# Look up the first tokenized FLAN example (not displayed: it is not the cell's last expression).
tokenized_train_df_flan[0]

# Classification head sized to the number of label classes seen by the encoder.
n_classes = len(le.classes_)
model = RobertaForSequenceClassification.from_pretrained(model_name, num_labels=n_classes).to(device_name)

In [None]:
from sklearn.metrics import accuracy_score

def compute_metrics(pred):
    """Return {'accuracy': ...} for a prediction object.

    `pred.predictions` is reduced with argmax over its last axis to get the predicted
    class per example, which is then scored against `pred.label_ids`.
    """
    predicted_classes = pred.predictions.argmax(-1)
    return {'accuracy': accuracy_score(pred.label_ids, predicted_classes)}


# trainer_text = Trainer(
#     model=model,                         # the instantiated 🤗 Transformers model to be trained
#     args=training_args,                  # training arguments, defined above
#     train_dataset=tokenized_train_df_text,         # training dataset
#     compute_metrics=compute_metrics      # our custom evaluation function
# )
# trainer_text.train()
# trainer_text.save_model('/kaggle/working/results/text')

# trainer_mcad = Trainer(
#     model=model,                         
#     args=training_args,                 
#     train_dataset=tokenized_train_df_mcad,      
#     compute_metrics=compute_metrics      
# )
# trainer_mcad.train()
# trainer_mcad.save_model('/kaggle/working/results/mcad')

# trainer_gpt = Trainer(
#     model=model,                         
#     args=training_args,                 
#     train_dataset=tokenized_train_df_gpt,      
#     compute_metrics=compute_metrics      
# )
# trainer_gpt.train()
# trainer_gpt.save_model('/kaggle/working/results/gpt')

# Fine-tune `model` on the FLAN-augmented training set and save the result.
# NOTE(review): the commented-out trainers above all pass the same `model` object; if several
# are enabled in one session, each would continue from the previous one's weights unless the
# model is re-created in between — confirm the intended setup.
# NOTE(review): no eval_dataset is passed here, so `compute_metrics` is presumably not
# invoked by train() itself — confirm where evaluation happens.
trainer_flan = Trainer(
    model=model,                         
    args=training_args,                 
    train_dataset=tokenized_train_df_flan,      
    compute_metrics=compute_metrics      
)
trainer_flan.train()
trainer_flan.save_model('/kaggle/working/results/flan')