# 1. Install libs and setup

In [1]:
# Environment setup. Library versions are pinned so later cells reproduce.
# NOTE(review): pip itself is upgraded without a pin (the log below shows it pulled 24.1.2).
%pip install --upgrade pip
# PyTorch and torchdata, pinned together.
%pip install --disable-pip-version-check \
    torch==1.13.1 \
    torchdata==0.5.1 --quiet

# Hugging Face stack (models, datasets, metrics, ROUGE) and PEFT, all pinned.
%pip install \
    transformers==4.27.2 \
    datasets==2.11.0 \
    evaluate==0.4.0 \
    rouge_score==0.1.2 \
    peft==0.3.0 --quiet

# Installing the Reinforcement Learning library directly from github.
# Pinned to commit 25fa1bd so the trl API used below does not drift.
%pip install git+https://github.com/lvwerra/trl.git@25fa1bd
# NOTE(review): TensorFlow is not imported in any visible cell of this notebook; confirm it is needed.
%pip install --upgrade tensorflow==2.12.0 tensorflow-probability==0.19.0
#%pip install --upgrade tensorflow tensorflow-probability

Collecting pip
  Downloading pip-24.1.2-py3-none-any.whl (1.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m25.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.1.2
    Uninstalling pip-23.1.2:
      Successfully uninstalled pip-23.1.2
Successfully installed pip-24.1.2
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m887.5/887.5 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.6/4.6 MB[0m [31m92.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.1/317.1 MB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.0/21.0 MB[0m [31m91.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m849.3/849.3 kB[0m [31m57.2 MB/s[0m eta [36m0:00:00[0m


In [1]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

# trl: Transformer Reinforcement Learning library
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate

import numpy as np
import pandas as pd

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()

# 2. Task 1. Fine-Tune FLAN-T5 with Reinforcement Learning (PPO) and PEFT to Generate Less-Toxic Summaries


## 2.1 - Load data, FLAN-T5 model, Reference/PPO model.

In [2]:
# Base model to fine-tune and the dialogue-summarization dataset that supplies the prompts.
model_name="google/flan-t5-base"
huggingface_dataset_name = "knkarthick/dialogsum"

# Download all splits, or reuse the cached copy on later runs. The cell displays the DatasetDict.
dataset_original = load_dataset(huggingface_dataset_name)
dataset_original

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/4.65k [00:00<?, ?B/s]

Downloading and preparing dataset csv/knkarthick--dialogsum to /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-cd36827d3490488d/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/11.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.35M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/442k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-cd36827d3490488d/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
})

In [3]:
# Peek at the first raw training example (fields: id, dialogue, summary, topic).
dataset_original['train'][0]

{'id': 'train_0',
 'dialogue': "#Person1#: Hi, Mr. Smith. I'm Doctor Hawkins. Why are you here today?\n#Person2#: I found it would be a good idea to get a check-up.\n#Person1#: Yes, well, you haven't had one for 5 years. You should have one every year.\n#Person2#: I know. I figure as long as there is nothing wrong, why go see the doctor?\n#Person1#: Well, the best way to avoid serious illnesses is to find out about them early. So try to come at least once a year for your own good.\n#Person2#: Ok.\n#Person1#: Let me see here. Your eyes and ears look fine. Take a deep breath, please. Do you smoke, Mr. Smith?\n#Person2#: Yes.\n#Person1#: Smoking is the leading cause of lung cancer and heart disease, you know. You really should quit.\n#Person2#: I've tried hundreds of times, but I just can't seem to kick the habit.\n#Person1#: Well, we have classes and some medications that might help. I'll give you more information before you leave.\n#Person2#: Ok, thanks doctor.",
 'summary': "Mr. Smith'

In [4]:
def build_dataset(model_name,
                  dataset_name,
                  input_min_text_length,
                  input_max_text_length):
    """
    Load the "train" split of a dialogue dataset, keep mid-length dialogues,
    attach tokenized summarization prompts, and split into train/test parts.

    Parameters:
    - model_name (str): Pretrained model whose tokenizer encodes the prompts.
    - dataset_name (str): Hugging Face dataset identifier; only its "train" split is read.
    - input_min_text_length (int): Exclusive lower bound on dialogue length, in characters.
    - input_max_text_length (int): Inclusive upper bound on dialogue length, in characters.

    Returns:
    - dataset_splits (datasets.dataset_dict.DatasetDict): "train" (80%) and "test" (20%)
      splits in original row order. Each row gains "input_ids" (the encoded prompt) and
      "query" (the prompt decoded back to text), with torch formatting enabled.
    """

    def keep_dialogue(example):
        # Character-length window (min, max].
        n_chars = len(example["dialogue"])
        return input_min_text_length < n_chars <= input_max_text_length

    raw_train = load_dataset(dataset_name, split="train")
    in_range = raw_train.filter(keep_dialogue, batched=False)

    # device_map is forwarded unchanged; it presumably has no effect on a tokenizer.
    tok = AutoTokenizer.from_pretrained(model_name, device_map="auto")

    def add_prompt(example):
        # Wrap the dialogue in the summarization instruction.
        prompt = f"""
Summarize the following conversation.

{example["dialogue"]}

Summary:
"""
        token_ids = tok.encode(prompt)
        example["input_ids"] = token_ids
        # The PPO library requires the text form of the prompt under the key "query".
        example["query"] = tok.decode(token_ids)
        return example

    prompted = in_range.map(add_prompt, batched=False)
    prompted.set_format(type="torch")

    # Deterministic 80/20 split. With shuffle=False the seed should not matter.
    return prompted.train_test_split(test_size=0.2, shuffle=False, seed=42)

# Build the PPO prompt dataset from dialogues of 201-1000 characters.
dataset = build_dataset(model_name=model_name,
                        dataset_name=huggingface_dataset_name,
                        input_min_text_length=200,
                        input_max_text_length=1000)

print(dataset)



Filter:   0%|          | 0/12460 [00:00<?, ? examples/s]



tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Map:   0%|          | 0/10022 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 8017
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 2005
    })
})


In [5]:
# Show one fully rendered prompt. The decoded text keeps the trailing "</s>" token (see output).
print(dataset['train'][0]['query'])

Summarize the following conversation. #Person1#: Hi, Mr. Smith. I'm Doctor Hawkins. Why are you here today? #Person2#: I found it would be a good idea to get a check-up. #Person1#: Yes, well, you haven't had one for 5 years. You should have one every year. #Person2#: I know. I figure as long as there is nothing wrong, why go see the doctor? #Person1#: Well, the best way to avoid serious illnesses is to find out about them early. So try to come at least once a year for your own good. #Person2#: Ok. #Person1#: Let me see here. Your eyes and ears look fine. Take a deep breath, please. Do you smoke, Mr. Smith? #Person2#: Yes. #Person1#: Smoking is the leading cause of lung cancer and heart disease, you know. You really should quit. #Person2#: I've tried hundreds of times, but I just can't seem to kick the habit. #Person1#: Well, we have classes and some medications that might help. I'll give you more information before you leave. #Person2#: Ok, thanks doctor. Summary: </s>


In [70]:
def print_number_of_trainable_model_parameters(model):
    """
    Summarize how many of a model's parameters are trainable.

    Despite its name, this function does not print. It returns a multi-line
    string that the caller prints.

    Parameters:
    - model (torch.nn.Module): Any module exposing ``named_parameters()``.

    Returns:
    - str: trainable parameter count, total parameter count and the trainable
      percentage (2 decimals). A model with no parameters reports 0.00% instead
      of raising ZeroDivisionError.
    """
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        n_elements = param.numel()
        all_model_params += n_elements
        if param.requires_grad:
            trainable_model_params += n_elements
    # Guard the division: an empty module would otherwise crash.
    percentage = 100 * trainable_model_params / all_model_params if all_model_params else 0.0
    return (f"\ntrainable model parameters: {trainable_model_params}"
            f"\nall model parameters: {all_model_params}"
            f"\npercentage of trainable model parameters: {percentage:.2f}%")

In [3]:
# Disabled: this originally pulled the PEFT checkpoint from S3, which fails without AWS
# credentials (see the stale "Unable to locate credentials" output below). The commands are
# wrapped in a string literal so they do not run; note the string is still displayed as the
# cell's output. The checkpoint is cloned from GitHub in the next cell instead.
'''
!aws s3 cp --recursive s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/ ./peft-dialogue-summary-checkpoint-from-s3/
!ls -alh ./peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin
# instruct_model_name=‘truocpham/flan-dialogue-summary-checkpoint’
#!aws s3 ls --profile profile1
!aws configure list
!aws s3 ls --profile profile
'''
#!git clone https://github.com/prajuktadey/gen-ai.git

fatal error: Unable to locate credentials
ls: cannot access './peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin': No such file or directory


In [8]:
# Fetch the PEFT dialogue-summary checkpoint, loaded below from ./gen-ai/.
# NOTE(review): git refuses to clone into an existing non-empty directory, so re-running this cell errors.
!git clone https://github.com/prajuktadey/gen-ai.git

Cloning into 'gen-ai'...
remote: Enumerating objects: 30, done.[K
remote: Counting objects: 100% (30/30), done.[K
remote: Compressing objects: 100% (28/28), done.[K
remote: Total 30 (delta 7), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (30/30), 13.30 MiB | 6.92 MiB/s, done.
Resolving deltas: 100% (7/7), done.


In [None]:
# LoRA settings: rank-32 updates on T5's attention "q" and "v" projections.
# NOTE(review): PeftModel.from_pretrained (peft 0.3.0) reads the adapter config saved in the
# checkpoint directory; the `lora_config=` and `torch_dtype=` kwargs passed below appear to be
# ignored. Verify before relying on these values.
lora_config = LoraConfig(
    r=32, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)


#model_name="google/flan-t5-base"
# Base FLAN-T5 weights in bfloat16.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name,
                                              torch_dtype=torch.bfloat16)

# Attach the cloned PEFT adapter. is_trainable=True keeps the adapter weights trainable for PPO.
peft_model = PeftModel.from_pretrained(model,
                                       #'./peft-dialogue-summary-checkpoint-from-s3/',
                                       './gen-ai/peft-dialogue-summary-checkpoint',
                                       lora_config=lora_config,
                                       torch_dtype=torch.bfloat16,
                                       #device_map="auto",
                                       is_trainable=True)

print(f'PEFT model parameters to be updated:\n{print_number_of_trainable_model_parameters(peft_model)}\n')

In [10]:
# Wrap the PEFT model with a scalar value head (v_head), as trl's PPOTrainer expects.
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,
                                                               torch_dtype=torch.bfloat16,
                                                               is_trainable=True)

# Trainable = adapter weights from the PEFT model + 769 value-head weights (768 inputs x 1 output + 1 bias).
print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}\n')
print(ppo_model.v_head)


PPO model parameters to be updated (ValueHead + 769 params):

trainable model parameters: 3539713
all model parameters: 251117569
percentage of trainable model parameters: 1.41%

ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=768, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)


During PPO, only a small fraction of the parameters is updated: the PEFT adapter weights and the ValueHead (1.41% of all parameters, per the output above). More information about this class of models can be found in the documentation. The number of trainable ValueHead parameters can be computed as $(n+1) \cdot m$, where $n$ is the number of input units (here $n=768$) and $m$ is the number of output units (you have $m=1$). The $+1$ term in the equation takes into account the bias term, giving $(768+1) \cdot 1 = 769$.

Now create a frozen copy of the PPO model, which will not be fine-tuned: the reference model. The reference model represents the LLM before detoxification. None of its parameters will be updated during PPO training. This is on purpose.



In [11]:
# Frozen copy of the PPO model. PPOTrainer compares the policy against it (objective/kl).
ref_model = create_reference_model(ppo_model)
print(f'Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n')

Reference model parameters to be updated:

trainable model parameters: 0
all model parameters: 251117569
percentage of trainable model parameters: 0.00%



## 2.2 - Prepare Reward Model


In [12]:
# Reward model: a RoBERTa hate-speech classifier. Its labels are printed below (0: nothate, 1: hate).
toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
# NOTE(review): device_map presumably has no effect on a tokenizer.
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map="auto")
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map="auto")
print(toxicity_model.config.id2label)



tokenizer_config.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/816 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

{0: 'nothate', 1: 'hate'}


In [18]:
#explore
# Score two hand-written summaries with the reward model to show how the "nothate"
# logit behaves as a reward: high for benign text, low for insulting text.

# Column of the "nothate" class in the classifier logits (id2label: {0: 'nothate', 1: 'hate'}).
not_hate_index = 0


def show_toxicity_scores(text, reward_label):
    """Print logits, probabilities and the "nothate" reward for ``text``.

    Parameters:
    - text (str): Text to classify with ``toxicity_model``.
    - reward_label (str): Tag printed next to the reward, e.g. "high" or "low".

    Returns:
    - list[float]: The "nothate" logit, which is the reward signal.
    """
    input_ids = toxicity_tokenizer(text, return_tensors="pt").input_ids
    # Move the inputs to the same device as the model.
    input_ids = input_ids.to(toxicity_model.device)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = toxicity_model(input_ids=input_ids).logits
    print(f'logits [not hate, hate]: {logits.tolist()[0]}')

    # Print the probabilities for [not hate, hate]
    probabilities = logits.softmax(dim=-1).tolist()[0]
    print(f'probabilities [not hate, hate]: {probabilities}')

    # The raw "nothate" logit is the reward.
    reward = (logits[:, not_hate_index]).tolist()
    print(f'reward ({reward_label}): {reward}')
    return reward


non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."
print("TEXT (non_toxic_text): ", non_toxic_text)
nothate_reward = show_toxicity_scores(non_toxic_text, "high")

toxic_text = "#Person 1# tells Tommy that the movie was terrible, dumb and stupid."
print("TEXT (toxic_text): ", toxic_text)
nothate_reward = show_toxicity_scores(toxic_text, "low")

TEXT (non_toxic_text):  #Person 1# tells Tommy that he didn't like the movie.
logits [not hate, hate]: [3.114102363586426, -2.489619016647339]
probabilities [not hate, hate]: [0.9963293671607971, 0.0036706042010337114]
reward (high): [3.114102363586426]
TEXT (toxic_text):  #Person 1# tells Tommy that the movie was terrible, dumb and stupid.
logits [not hate, hate]: [-0.6921164393424988, 0.372270792722702]
probabilities [not hate, hate]: [0.2564719319343567, 0.7435280084609985]
reward (low): [-0.6921164393424988]


In [19]:
# GPU 0 if available, else CPU. Used by the pipeline below and elsewhere in the notebook.
device = 0 if torch.cuda.is_available() else "cpu"

# The same reward model wrapped in a pipeline so lists of texts can be scored in batches.
sentiment_pipe = pipeline("sentiment-analysis",
                          model=toxicity_model_name,
                          device=device)
# Raw logits: the form used as the PPO reward.
reward_logits_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # Set to "none" to retrieve raw logits.
    "batch_size": 16
}

# Softmax probabilities, printed below for comparison with the logits.
reward_probabilities_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "softmax", # Set to "softmax" to apply softmax and retrieve probabilities.
    "batch_size": 16
}

# NOTE: with top_k=None the results are ordered by score, not by class id. In the toxic-text
# output below, 'hate' comes first. Look scores up by label, never by list position.
print("Reward model output:")
print("For non-toxic text")
print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))
print("For toxic text")
print(sentiment_pipe(toxic_text, **reward_logits_kwargs))
print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))



Reward model output:
For non-toxic text
[{'label': 'nothate', 'score': 3.114102363586426}, {'label': 'hate', 'score': -2.489619016647339}]
[{'label': 'nothate', 'score': 0.9963293671607971}, {'label': 'hate', 'score': 0.0036706042010337114}]
For toxic text
[{'label': 'hate', 'score': 0.372270792722702}, {'label': 'nothate', 'score': -0.6921164393424988}]
[{'label': 'hate', 'score': 0.7435280084609985}, {'label': 'nothate', 'score': 0.2564719319343567}]


## 2.3 - Evaluate Toxicity

In [20]:
# `evaluate`'s toxicity measurement, backed by the same classifier. It reports the probability
# of the "hate" label for each prediction (compare with the softmax values printed earlier).
toxicity_evaluator = evaluate.load("toxicity",
                                    toxicity_model_name,
                                    module_type="measurement",
                                    toxic_label="hate")

toxicity_score = toxicity_evaluator.compute(predictions=[
    non_toxic_text
])

print("Toxicity score for non-toxic text:")
print(toxicity_score["toxicity"])

toxicity_score = toxicity_evaluator.compute(predictions=[
    toxic_text
])

print("\nToxicity score for toxic text:")
print(toxicity_score["toxicity"])

Downloading builder script:   0%|          | 0.00/6.08k [00:00<?, ?B/s]

Toxicity score for non-toxic text:
[0.003670616541057825]

Toxicity score for toxic text:
[0.7435289621353149]


In [38]:
def evaluate_toxicity(model,
                      toxicity_evaluator,
                      tokenizer,
                      dataset,
                      num_samples,
                      max_new_tokens=100):

    """
    Estimate how toxic a model's sampled completions are on a set of prompts.

    For each of the first ``num_samples`` rows of ``dataset``, the model samples a
    completion of the row's "query" prompt. The prompt and completion are scored
    together by ``toxicity_evaluator``, and the per-sample scores are aggregated.

    Parameters:
    - model (trl / transformers model): Model to be evaluated. It must provide
      ``generate`` and ``parameters()``.
    - toxicity_evaluator (evaluate toxicity measurement): Toxicity evaluator.
    - tokenizer (transformers tokenizer): Tokenizer matching ``model``.
    - dataset (iterable of dicts with a "query" key): Prompts for the evaluation.
    - num_samples (int): Maximum number of samples to evaluate.
    - max_new_tokens (int): Generation length cap per sample (default 100).

    Returns:
    tuple: A tuple containing two numpy.float64 values:
    - mean (numpy.float64): Mean of the samples toxicity.
    - std (numpy.float64): Standard deviation of the samples toxicity.
    Both are nan if no samples were evaluated.
    """
    from itertools import islice

    # Put inputs on whatever device holds the model's weights.
    model_device = next(model.parameters()).device

    # Pure sampling: top_k=0 disables top-k filtering and top_p=1.0 disables nucleus
    # filtering. This matches the generation settings of the PPO rollouts.
    generation_config = GenerationConfig(max_new_tokens=max_new_tokens,
                                         top_k=0,
                                         top_p=1.0,
                                         do_sample=True)

    toxicities = []
    # islice caps the loop at exactly num_samples rows (or fewer if the dataset is shorter).
    for sample in tqdm(islice(dataset, max(num_samples, 0))):
        input_text = sample["query"]

        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model_device)

        response_token_ids = model.generate(input_ids=input_ids,
                                            generation_config=generation_config)

        generated_text = tokenizer.decode(response_token_ids[0], skip_special_tokens=True)

        # Score prompt + completion together; the PPO reward likewise scores query+response.
        toxicity_score = toxicity_evaluator.compute(predictions=[(input_text + " " + generated_text)])

        toxicities.extend(toxicity_score["toxicity"])

    # Compute mean & std using np.
    mean = np.mean(toxicities)
    std = np.std(toxicities)

    return mean, std

In [22]:
# Inspect the splits, then take a 5-row test subset. Generation-based toxicity evaluation
# takes seconds per sample, so this keeps the before/after comparison fast.
display(dataset)
sub_dataset = dataset["test"].select(range(5))
display(sub_dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 8017
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
        num_rows: 2005
    })
})

Dataset({
    features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],
    num_rows: 5
})

In [23]:
# FLAN-T5 tokenizer, used for evaluation here and later passed to the PPOTrainer.
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")

# Baseline toxicity of the frozen reference model (the LLM before detoxification).
# NOTE(review): only the 5-row `sub_dataset` is evaluated, so num_samples=10 limits nothing here.
mean_before_detoxification, std_before_detoxification = evaluate_toxicity(model=ref_model,
                                                                          toxicity_evaluator=toxicity_evaluator,
                                                                          tokenizer=tokenizer,
                                                                          #dataset=dataset["test"],
                                                                          dataset=sub_dataset,
                                                                          num_samples=10)

print(f'toxicity [mean, std] before detox: [{mean_before_detoxification}, {std_before_detoxification}]')

5it [00:11,  2.25s/it]

toxicity [mean, std] before detox: [0.007212279317900539, 0.006383725121443723]





## 2.4 Perform Fine-Tuning to Detoxify the Summaries

### 2.4.1 Initialize PPOTrainer

In [24]:
def collator(data):
    """Turn a list of per-example dicts into one dict of per-key lists.

    Keys (and their order) come from the first example. An empty ``data``
    raises IndexError.
    """
    first_example = data[0]
    return {key: [example[key] for example in data] for key in first_example}

test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')

Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}]
Collator output: {'key1': ['value1'], 'key2': ['value2'], 'key3': ['value3']}


In [25]:
# PPO hyperparameters: 16 prompts per PPO step, optimized in mini-batches of 4,
# with one optimization epoch over each batch.
learning_rate=1.41e-5
max_ppo_epochs=1
mini_batch_size=4
batch_size=16

config = PPOConfig(
    model_name=model_name,
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size
)

# The trainer builds its dataloader from dataset["train"]. `collator` keeps every field as a
# plain per-key list, so variable-length input_ids are never stacked into one tensor.
ppo_trainer = PPOTrainer(config=config,
                         model=ppo_model,
                         ref_model=ref_model,
                         tokenizer=tokenizer,
                         dataset=dataset["train"],
                         data_collator=collator)

### 2.4.2 Fine-Tune the Model

The fine-tuning loop consists of the following main steps:

1. Get the query responses from the policy LLM (PEFT model).
2. Get toxicity rewards for the query/response pairs from the hate-speech RoBERTa model.
3. Optimize policy with PPO using the (query, response, reward) triplet.


The operation is running if you see the following metrics appearing:

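* objective/kl: KL divergence from the reference model, which should stay small,
* ppo/returns/mean: mean returns, which should increase,
* ppo/policy/advantages_mean: mean advantage. Advantages appear to be normalized, so this stays close to zero (about 1e-8 in the log below).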

In [26]:
output_min_length = 100
output_max_length = 400
# Draws a random max_new_tokens between these bounds for every rollout.
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# Pure sampling (no top-k / nucleus filtering), the same settings evaluate_toxicity uses.
generation_kwargs = {
    "min_length": 5,
    "top_k": 0,
    "top_p": 1.0,
    "do_sample": True
}

reward_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # You want the raw logits without softmax.
    "batch_size": 16
}

# Label whose raw logit is the reward (higher = less toxic).
reward_label = "nothate"


def get_label_score(label_scores, label):
    """Return the score of ``label`` from one text-classification pipeline result.

    With top_k=None the pipeline sorts labels by score, so position 0 is whichever
    label scored highest ('hate' for toxic text). The label must be looked up by name.

    Raises:
    - KeyError: if ``label`` is absent from ``label_scores``.
    """
    for item in label_scores:
        if item["label"] == label:
            return item["score"]
    raise KeyError(f"label {label!r} not found in reward output: {label_scores}")


max_ppo_steps = 10

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    # Break when you reach max_steps.
    if step >= max_ppo_steps:
        break

    prompt_tensors = batch["input_ids"]

    # Get response from FLAN-T5/PEFT LLM.
    summary_tensors = []

    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()

        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)

        summary_tensors.append(summary.squeeze()[-max_new_tokens:])

    # This needs to be called "response". Special tokens (<pad>, </s>) are stripped so
    # the reward model scores only real text, exactly as in evaluate_toxicity.
    batch["response"] = [tokenizer.decode(r.squeeze(), skip_special_tokens=True) for r in summary_tensors]

    # Compute reward outputs on "query response", the same text layout evaluate_toxicity scores.
    query_response_pairs = [q + " " + r for q, r in zip(batch["query"], batch["response"])]
    rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)

    # Reward = raw "nothate" logit, looked up by label (list order depends on the scores).
    reward_tensors = [torch.tensor(get_label_score(reward, reward_label)) for reward in rewards]

    # Run PPO step.
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)

    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
    print('-' * 100)

0it [00:00, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
1it [00:24, 24.77s/it]

objective/kl: 30.958946228027344
ppo/returns/mean: -0.4403095543384552
ppo/policy/advantages_mean: -1.4564196959554465e-08
---------------------------------------------------------------------------------------------------


2it [00:45, 22.65s/it]

objective/kl: 29.040725708007812
ppo/returns/mean: -0.32841476798057556
ppo/policy/advantages_mean: 7.589676442876225e-09
---------------------------------------------------------------------------------------------------


3it [01:05, 21.35s/it]

objective/kl: 26.405418395996094
ppo/returns/mean: -0.3114323019981384
ppo/policy/advantages_mean: -1.9711601240146592e-09
---------------------------------------------------------------------------------------------------


4it [01:24, 20.23s/it]

objective/kl: 23.95636558532715
ppo/returns/mean: -0.11056496202945709
ppo/policy/advantages_mean: -8.830317810293309e-09
---------------------------------------------------------------------------------------------------


5it [01:44, 20.32s/it]

objective/kl: 28.987567901611328
ppo/returns/mean: -0.3714064359664917
ppo/policy/advantages_mean: 5.108677036957943e-09
---------------------------------------------------------------------------------------------------


6it [02:03, 19.82s/it]

objective/kl: 24.597410202026367
ppo/returns/mean: -0.23697587847709656
ppo/policy/advantages_mean: -1.188559028975078e-08
---------------------------------------------------------------------------------------------------


7it [02:26, 20.89s/it]

objective/kl: 31.728172302246094
ppo/returns/mean: -0.5597826838493347
ppo/policy/advantages_mean: -1.3312913438312535e-08
---------------------------------------------------------------------------------------------------


8it [02:46, 20.59s/it]

objective/kl: 23.9593448638916
ppo/returns/mean: -0.19444648921489716
ppo/policy/advantages_mean: -1.4930579439464964e-08
---------------------------------------------------------------------------------------------------


9it [03:04, 19.70s/it]

objective/kl: 23.22824478149414
ppo/returns/mean: -0.19702671468257904
ppo/policy/advantages_mean: -7.753449438041571e-09
---------------------------------------------------------------------------------------------------


10it [03:23, 20.35s/it]

objective/kl: 27.128936767578125
ppo/returns/mean: -0.3643034100532532
ppo/policy/advantages_mean: -3.0329161404551996e-09
---------------------------------------------------------------------------------------------------





In [39]:
# Same evaluation on the PPO-tuned policy and the same 5-row `sub_dataset`, for a before/after comparison.
mean_after_detoxification, std_after_detoxification = evaluate_toxicity(model=ppo_model,
                                                                        toxicity_evaluator=toxicity_evaluator,
                                                                        tokenizer=tokenizer,
                                                                        #dataset=dataset["test"],
                                                                        dataset=sub_dataset,
                                                                        num_samples=10)
print(f'toxicity [mean, std] after detox: [{mean_after_detoxification}, {std_after_detoxification}]')

5it [00:06,  1.20s/it]

toxicity [mean, std] after detox: [0.0067623034352436665, 0.006098020820219592]





# 3. Task 2. Idiom to Straightforward Expression

## 3.1 Load data, model, checkpoint, PEFT, PPO and reference models

In [57]:
# NOTE(review): load_dataset is already imported in the imports cell at the top; this re-import is redundant.
from datasets import load_dataset

# Idiom -> straightforward-expression pairs from local CSV files (train / validate / test).
# NOTE(review): when column_names is passed, the CSV reader treats every line as data. If these
# files have a header row, it will show up as an example. Confirm.
column_names= ['idiom', 'straightforward']
custom_dataset = load_dataset("csv", data_files={"train": "./sample_data/train.csv" , "validate": "./sample_data/eval.csv", "test": "./sample_data/test.csv"}, column_names=column_names)
display(custom_dataset)



  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['idiom', 'straightforward'],
        num_rows: 105
    })
    validate: Dataset({
        features: ['idiom', 'straightforward'],
        num_rows: 7
    })
    test: Dataset({
        features: ['idiom', 'straightforward'],
        num_rows: 14
    })
})

In [72]:
def build_dataset(model_name,
                  dataset,
                  input_min_text_length,
                  input_max_text_length,
                  prompt_template=None):

    """
    Filter idioms by length and add the tokenized instruction prompt to every split.

    No train/test split is created: every split already present in `dataset` is
    filtered and tokenized independently.

    Parameters:
    - model_name (str): Tokenizer name/path passed to AutoTokenizer.from_pretrained.
    - dataset (datasets.DatasetDict): Dataset whose splits all have an "idiom" column.
    - input_min_text_length (int): Idioms must be strictly longer than this many characters.
    - input_max_text_length (int): Idioms may be at most this many characters long.
    - prompt_template (str, optional): Prompt containing an ``{idiom}`` placeholder
      (filled with str.format, so any other literal braces must be doubled).
      Defaults to the original instruction prompt, kept byte-identical (including the
      "Trun" typo) so inputs match what existing checkpoints saw.
      NOTE(review): confirm the checkpoints' training prompt before fixing the wording.

    Returns:
    - dataset (datasets.DatasetDict): The same splits, with rows whose idiom is missing
      or outside the length bounds removed, plus an "input_ids" column (torch format)
      and a "query" column (the decoded prompt, required by the PPO library).
    """
    if prompt_template is None:
        prompt_template = (
            "\n        Trun this idiomatic expression into a more straightforward statement?"
            "\nidiom: {idiom} \nA straightforward expression:"
        )

    def _keep(sample):
        # Empty CSV cells arrive as None; drop them instead of crashing in len().
        idiom = sample["idiom"]
        return idiom is not None and input_min_text_length < len(idiom) <= input_max_text_length

    # Keep idioms longer than input_min_text_length and at most input_max_text_length characters.
    dataset = dataset.filter(_keep, batched=False)

    # Tokenizers run on CPU, so no device placement is needed here.
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize(sample):
        # Wrap each idiom with the instruction.
        prompt = prompt_template.format(idiom=sample["idiom"])

        sample["input_ids"] = tokenizer.encode(prompt)

        # This must be called "query", which is a requirement of our PPO library.
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    # Tokenize each idiom.
    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")

    return dataset


tokenized_dataset = build_dataset(model_name=model_name,
                                  dataset=custom_dataset,
                                  input_min_text_length=10,
                                  input_max_text_length=200)

# Raw vs. tokenized splits: row counts only differ if the length filter dropped idioms.
display(custom_dataset, tokenized_dataset)



DatasetDict({
    train: Dataset({
        features: ['idiom', 'straightforward'],
        num_rows: 105
    })
    validate: Dataset({
        features: ['idiom', 'straightforward'],
        num_rows: 7
    })
    test: Dataset({
        features: ['idiom', 'straightforward'],
        num_rows: 14
    })
})
DatasetDict({
    train: Dataset({
        features: ['idiom', 'straightforward', 'input_ids', 'query'],
        num_rows: 105
    })
    validate: Dataset({
        features: ['idiom', 'straightforward', 'input_ids', 'query'],
        num_rows: 7
    })
    test: Dataset({
        features: ['idiom', 'straightforward', 'input_ids', 'query'],
        num_rows: 14
    })
})


In [110]:
# Peek at the token ids of the first two training prompts.
for row_idx in (0, 1):
    print(tokenized_dataset["train"][row_idx]['input_ids'])

tensor([ 7953,    29,    48,     3, 19916,  4992,  3893,   139,     3,     9,
           72, 11753,  2493,    58,     3, 19916,    51,    10,    27,  4037,
            3,     9, 17625,    11,    34,    47,     3,     9,  1466,    13,
         4340,   227,   767,    13,   761,     5,    71, 11753,  3893,    10,
            1])
tensor([ 7953,    29,    48,     3, 19916,  4992,  3893,   139,     3,     9,
           72, 11753,  2493,    58,     3, 19916,    51,    10,    94,    47,
            3,     9,  1466,    13,  4340,    12,  1903,    82,  2535,    22,
            7,   794,     5,    71, 11753,  3893,    10,     1])


In [92]:
model_name = "google/flan-t5-base"  # checkpoint reused for the tokenizer, base model and PPOConfig

In [120]:
# LoRA settings for FLAN-T5 (sequence-to-sequence) with adapters on the "q" and "v" projections.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,  # FLAN-T5
    r=32,  # Rank
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q", "v"],
    bias="none",
)

model_name = "google/flan-t5-base"
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# Load the LoRA adapter checkpoint on top of the base model and keep it trainable for PPO.
# NOTE(review): PeftModel.from_pretrained reads the adapter config stored in ckp_path;
# the lora_config kwarg presumably does not override it — confirm against the PEFT version in use.
ckp_path = "./pert-idiom-ckp"
peft_model = PeftModel.from_pretrained(
    base_model,
    ckp_path,
    lora_config=lora_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    is_trainable=True,
)

print(f'PEFT model parameters to be updated:\n{print_number_of_trainable_model_parameters(peft_model)}')



PEFT model parameters to be updated:

trainable model parameters: 7077888
all model parameters: 254655744
percentage of trainable model parameters: 2.78%


In [121]:
# Add a value head on top of the PEFT model; its trainable count should exceed the PEFT
# model's by the head's 769 parameters.
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    peft_model,
    torch_dtype=torch.bfloat16,
    is_trainable=True,
)

print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}\n')
print(ppo_model.v_head)

PPO model parameters to be updated (ValueHead + 769 params):

trainable model parameters: 7078657
all model parameters: 254656513
percentage of trainable model parameters: 2.78%

ValueHead(
  (dropout): Dropout(p=0.1, inplace=False)
  (summary): Linear(in_features=768, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)


During PPO, only a small fraction of the parameters is updated: the LoRA adapter weights and the parameters of the `ValueHead`. More information about this class of models can be found in the documentation. The number of trainable parameters the `ValueHead` adds can be computed as $(n+1)\cdot m$, where $n$ is the number of input units (here $n=768$) and $m$ is the number of output units (you have $m=1$). The $+1$ term in the equation takes into account the bias term, giving $769$ additional trainable parameters.

Now create a frozen copy of the PPO model that will not be fine-tuned: the reference model. The reference model represents the LLM before PPO fine-tuning (in the detoxification task, the LLM before detoxification). None of the parameters of the reference model will be updated during PPO training; this is intentional.



In [122]:
# Frozen snapshot of the PPO model; PPOTrainer compares against it for the KL term.
ref_model = create_reference_model(ppo_model)

print(f'Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n')

Reference model parameters to be updated:

trainable model parameters: 0
all model parameters: 254656513
percentage of trainable model parameters: 0.00%



In [123]:
toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"

toxicity_tokenizer = AutoTokenizer.from_pretrained(
    toxicity_model_name,
    device_map="auto",
)
toxicity_model = AutoModelForSequenceClassification.from_pretrained(
    toxicity_model_name,
    device_map="auto",
)

# Show the label mapping so the index of the "nothate" class can be checked.
print(toxicity_model.config.id2label)

{0: 'nothate', 1: 'hate'}


## 3.2 PPO fine-tuning

In [124]:
def collator(data):
    """Transpose a list of row dicts into a dict of column lists (keys taken from the first row)."""
    return {key: [row[key] for row in data] for key in data[0]}

test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}]
print(f'Collator input: {test_data}')
print(f'Collator output: {collator(test_data)}')

Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}]
Collator output: {'key1': ['value1'], 'key2': ['value2'], 'key3': ['value3']}


In [125]:
# PPO hyper-parameters.
learning_rate = 1.41e-5
max_ppo_epochs = 5  # passed to PPOConfig as ppo_epochs
mini_batch_size = 4
batch_size = 16

config = PPOConfig(
    model_name=model_name,
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size,
)

ppo_trainer = PPOTrainer(
    config=config,
    model=ppo_model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=tokenized_dataset["train"],
    data_collator=collator,
)

In [128]:
output_min_length = 10
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

reward_kwargs = {
    "top_k": None, # Return all scores.
    "function_to_apply": "none", # You want the raw logits without softmax.
    "batch_size": 20
}

# Budget on the TOTAL number of PPO steps, counted across all epochs.
max_ppo_steps = 105*5

# Queries of every processed batch, kept for later inspection.
LOG_list = []


def _generate_responses(prompt_tensors):
    """Sample one response per prompt, each with a freshly drawn max_new_tokens.

    Returns a list of token tensors, each sliced to its last ``max_new_tokens`` ids.
    Side effect: sets "max_new_tokens" in the module-level ``generation_kwargs``.
    """
    response_tensors = []
    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()
        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)
        response_tensors.append(summary.squeeze()[-max_new_tokens:])
    return response_tensors


epochs = 5
global_step = 0
for epoch in tqdm(range(epochs), "epoch: "):
    if global_step >= max_ppo_steps:
        break
    print("#=== epoch: ", epoch)
    for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
        # Break when you reach max_steps (checked before any logging or generation).
        if global_step >= max_ppo_steps:
            print("== max steps", global_step)
            break
        print("#=== step: ", step, ", batch: ", len(batch['query']))
        LOG_list.append(batch['query'])

        prompt_tensors = batch["input_ids"]

        # Get response from FLAN-T5/PEFT LLM.
        summary_tensors = _generate_responses(prompt_tensors)

        # This needs to be called "response".
        batch["response"] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]

        # Compute reward outputs.
        # NOTE(review): `sentiment_pipe` and `not_hate_index` are not defined in this section
        # (presumably built in Task 1 from the hate-speech model) — confirm they exist on a fresh
        # kernel. This reward scores hatefulness, not whether the idiom was rephrased plainly.
        query_response_pairs = [q + r for q, r in zip(batch["query"], batch["response"])]
        rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)

        # You use the `nothate` item because this is the score for the positive `nothate` class.
        reward_tensors = [torch.tensor(reward[not_hate_index]["score"]) for reward in rewards]

        # Run PPO step.
        stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
        ppo_trainer.log_stats(stats, batch, reward_tensors)
        global_step += 1

        print(f'objective/kl: {stats["objective/kl"]}')
        print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
        print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}')
        print('-' * 99)

epoch:   0%|          | 0/5 [00:00<?, ?it/s]

#=== epoch:  0



0it [00:00, ?it/s][A

#=== step:  0 , batch:  16



1it [00:10, 10.80s/it][A

objective/kl: 2.297739267349243
ppo/returns/mean: 1.1895471811294556
ppo/policy/advantages_mean: -6.00558669461293e-09
---------------------------------------------------------------------------------------------------
#=== step:  1 , batch:  16



2it [00:20, 10.32s/it][A

objective/kl: 3.1768717765808105
ppo/returns/mean: 1.187159538269043
ppo/policy/advantages_mean: -1.5251968576990294e-08
---------------------------------------------------------------------------------------------------
#=== step:  2 , batch:  16



3it [00:30, 10.07s/it][A

objective/kl: 3.376898765563965
ppo/returns/mean: 1.3107110261917114
ppo/policy/advantages_mean: 2.84553514084962e-09
---------------------------------------------------------------------------------------------------
#=== step:  3 , batch:  16



4it [00:38,  9.39s/it][A

objective/kl: 3.9081149101257324
ppo/returns/mean: 1.1737864017486572
ppo/policy/advantages_mean: 2.0361916597266827e-08
---------------------------------------------------------------------------------------------------
#=== step:  4 , batch:  16



5it [00:48,  9.57s/it][A

objective/kl: 4.283422470092773
ppo/returns/mean: 1.0961557626724243
ppo/policy/advantages_mean: 6.65236132846303e-09
---------------------------------------------------------------------------------------------------
#=== step:  5 , batch:  16



6it [00:58,  9.79s/it]
epoch:  20%|██        | 1/5 [00:58<03:55, 58.75s/it]

objective/kl: 4.811492919921875
ppo/returns/mean: 1.109157681465149
ppo/policy/advantages_mean: 2.744127947096331e-08
---------------------------------------------------------------------------------------------------
#=== epoch:  1



0it [00:00, ?it/s][A

#=== step:  0 , batch:  16



1it [00:09,  9.71s/it][A

objective/kl: 2.163726806640625
ppo/returns/mean: 1.4286088943481445
ppo/policy/advantages_mean: -2.1991287013634064e-08
---------------------------------------------------------------------------------------------------
#=== step:  1 , batch:  16



2it [00:18,  9.17s/it][A

objective/kl: 3.0968892574310303
ppo/returns/mean: 1.4044406414031982
ppo/policy/advantages_mean: -1.9979658816282608e-08
---------------------------------------------------------------------------------------------------
#=== step:  2 , batch:  16



3it [00:26,  8.75s/it][A

objective/kl: 3.505913257598877
ppo/returns/mean: 1.1035131216049194
ppo/policy/advantages_mean: -3.4544656024593223e-09
---------------------------------------------------------------------------------------------------
#=== step:  3 , batch:  16



4it [00:35,  8.77s/it][A

objective/kl: 4.885770797729492
ppo/returns/mean: 1.3058750629425049
ppo/policy/advantages_mean: -5.849500439580879e-09
---------------------------------------------------------------------------------------------------
#=== step:  4 , batch:  16



5it [00:44,  8.68s/it][A

objective/kl: 5.975161552429199
ppo/returns/mean: 0.9673738479614258
ppo/policy/advantages_mean: 4.001809994491623e-09
---------------------------------------------------------------------------------------------------
#=== step:  5 , batch:  16



6it [00:52,  8.73s/it]
epoch:  40%|████      | 2/5 [01:51<02:44, 55.00s/it]

objective/kl: 3.3734655380249023
ppo/returns/mean: 1.3782621622085571
ppo/policy/advantages_mean: 5.1021256552985506e-08
---------------------------------------------------------------------------------------------------
#=== epoch:  2



0it [00:00, ?it/s][A

#=== step:  0 , batch:  16



1it [00:08,  8.07s/it][A

objective/kl: 3.1964926719665527
ppo/returns/mean: 1.212503433227539
ppo/policy/advantages_mean: -1.6620099074771133e-08
---------------------------------------------------------------------------------------------------
#=== step:  1 , batch:  16



2it [00:17,  9.15s/it][A

objective/kl: 3.208568572998047
ppo/returns/mean: 1.1326969861984253
ppo/policy/advantages_mean: -1.2867903187441243e-08
---------------------------------------------------------------------------------------------------
#=== step:  2 , batch:  16



3it [00:26,  8.78s/it][A

objective/kl: 3.3826751708984375
ppo/returns/mean: 1.2385021448135376
ppo/policy/advantages_mean: -6.650777262251495e-09
---------------------------------------------------------------------------------------------------
#=== step:  3 , batch:  16



4it [00:34,  8.61s/it][A

objective/kl: 3.3704261779785156
ppo/returns/mean: 1.64800226688385
ppo/policy/advantages_mean: 1.7884529412981465e-08
---------------------------------------------------------------------------------------------------
#=== step:  4 , batch:  16



5it [00:42,  8.41s/it][A

objective/kl: 3.7384347915649414
ppo/returns/mean: 1.2044007778167725
ppo/policy/advantages_mean: 8.088743896905726e-09
---------------------------------------------------------------------------------------------------
#=== step:  5 , batch:  16



6it [00:50,  8.46s/it]
epoch:  60%|██████    | 3/5 [02:41<01:46, 53.08s/it]

objective/kl: 2.582101821899414
ppo/returns/mean: 1.4126431941986084
ppo/policy/advantages_mean: -1.2027815410320386e-09
---------------------------------------------------------------------------------------------------
#=== epoch:  3



0it [00:00, ?it/s][A

#=== step:  0 , batch:  16



1it [00:08,  8.70s/it][A

objective/kl: 5.014919281005859
ppo/returns/mean: 1.0453990697860718
ppo/policy/advantages_mean: 1.8226256059961088e-09
---------------------------------------------------------------------------------------------------
#=== step:  1 , batch:  16



2it [00:16,  8.46s/it][A

objective/kl: 3.6472058296203613
ppo/returns/mean: 1.4857972860336304
ppo/policy/advantages_mean: -5.0168637244496495e-08
---------------------------------------------------------------------------------------------------
#=== step:  2 , batch:  16



3it [00:26,  8.76s/it][A

objective/kl: 2.2739086151123047
ppo/returns/mean: 1.1315993070602417
ppo/policy/advantages_mean: 6.126924745331053e-09
---------------------------------------------------------------------------------------------------
#=== step:  3 , batch:  16



4it [00:34,  8.63s/it][A

objective/kl: 4.607113838195801
ppo/returns/mean: 1.1180663108825684
ppo/policy/advantages_mean: -1.766993307228404e-08
---------------------------------------------------------------------------------------------------
#=== step:  4 , batch:  16



5it [00:42,  8.40s/it][A

objective/kl: 3.2744126319885254
ppo/returns/mean: 1.2545665502548218
ppo/policy/advantages_mean: 1.1344662986800813e-08
---------------------------------------------------------------------------------------------------
#=== step:  5 , batch:  16



6it [00:50,  8.45s/it]
epoch:  80%|████████  | 4/5 [03:32<00:52, 52.15s/it]

objective/kl: 2.2705037593841553
ppo/returns/mean: 1.660112738609314
ppo/policy/advantages_mean: -1.7172901989326306e-09
---------------------------------------------------------------------------------------------------
#=== epoch:  4



0it [00:00, ?it/s][A

#=== step:  0 , batch:  16



1it [00:07,  7.43s/it][A

objective/kl: 3.7409653663635254
ppo/returns/mean: 0.9941158294677734
ppo/policy/advantages_mean: 1.2886078870621986e-08
---------------------------------------------------------------------------------------------------
#=== step:  1 , batch:  16



2it [00:15,  8.00s/it][A

objective/kl: 4.120509147644043
ppo/returns/mean: 1.5187931060791016
ppo/policy/advantages_mean: 2.6768682825917267e-09
---------------------------------------------------------------------------------------------------
#=== step:  2 , batch:  16



3it [00:23,  7.96s/it][A

objective/kl: 3.847064971923828
ppo/returns/mean: 1.538988471031189
ppo/policy/advantages_mean: -1.1317742298899702e-08
---------------------------------------------------------------------------------------------------
#=== step:  3 , batch:  16



4it [00:32,  8.30s/it][A

objective/kl: 2.4815659523010254
ppo/returns/mean: 1.219718337059021
ppo/policy/advantages_mean: -1.6029174432219406e-08
---------------------------------------------------------------------------------------------------
#=== step:  4 , batch:  16



5it [00:40,  8.35s/it][A

objective/kl: 3.5949759483337402
ppo/returns/mean: 1.365090012550354
ppo/policy/advantages_mean: -1.6780949962935665e-08
---------------------------------------------------------------------------------------------------
#=== step:  5 , batch:  16



6it [00:49,  8.24s/it]
epoch: 100%|██████████| 5/5 [04:22<00:00, 52.42s/it]

objective/kl: 2.4623682498931885
ppo/returns/mean: 1.1212338209152222
ppo/policy/advantages_mean: 1.2877970689828544e-08
---------------------------------------------------------------------------------------------------





In [131]:
print("=====SAVE====\n")
ppo_model_path = "./ppo-checkpoint-local-ph5"

# Write the trained PPO model and the tokenizer into the same directory;
# the tokenizer's returned file list is the cell's displayed output.
ppo_trainer.model.save_pretrained(ppo_model_path)
tokenizer.save_pretrained(ppo_model_path)

=====SAVE====



('./ppo-checkpoint-local-ph5/tokenizer_config.json',
 './ppo-checkpoint-local-ph5/special_tokens_map.json',
 './ppo-checkpoint-local-ph5/spiece.model',
 './ppo-checkpoint-local-ph5/added_tokens.json',
 './ppo-checkpoint-local-ph5/tokenizer.json')

In [111]:
tokenized_dataset["test"]  # quick look at the held-out split

Dataset({
    features: ['idiom', 'straightforward', 'input_ids', 'query'],
    num_rows: 14
})

In [132]:
from peft import PeftModel, PeftConfig

# Rebuild the policy from the saved PPO checkpoint: a fresh base model plus the saved
# adapter, frozen for inference. Uses the same `model_name` the policy was trained from.
peft_model_base = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

final_model = PeftModel.from_pretrained(peft_model_base,
                                        './ppo-checkpoint-local-ph5',
                                        torch_dtype=torch.bfloat16,
                                        device_map="auto",
                                        is_trainable=False)

print(print_number_of_trainable_model_parameters(final_model))

# Inputs must live on the same device as the model weights (device_map="auto" may use a GPU).
model_device = next(final_model.parameters()).device
generation_config = GenerationConfig(max_new_tokens=200)

idioms = tokenized_dataset['test']  # held-out split, iterated row by row below
model_results = []

for idx in range(len(idioms)):
    # Add a batch dimension and move to the model's device.
    input_ids = idioms[idx]['input_ids'].unsqueeze(0).to(model_device)

    outputs = final_model.generate(input_ids=input_ids, generation_config=generation_config)

    for output in outputs:
        model_results.append(tokenizer.decode(output, skip_special_tokens=True))

display(model_results)

idiom_texts = idioms['idiom']
# zip() would silently truncate if the number of generations differed from the number of idioms.
assert len(model_results) == len(idiom_texts), (len(model_results), len(idiom_texts))
zipped_results = list(zip(idiom_texts, model_results))

df = pd.DataFrame(zipped_results, columns = ['idioms', 'RLHF_model_results'])
df.to_csv('RLHF_model_results.csv')
display(df)




trainable model parameters: 0
all model parameters: 254655744
percentage of trainable model parameters: 0.00%


['He expected a warm welcomem, but instead, he was given the cold shoulder.',
 'to',
 'idiom',
 "You didn't reach a conclusion about project yet.",
 'My mom will go bananas if I forgot to feed the dog again.',
 "You'll end up going bananas.",
 'Things go south.',
 'The performance in the last quarter went south.',
 "You're going to break a leg tonight.",
 'You can do it.',
 'I bought it on sale, and it still cost me a million won.',
 'This ard cost me an arm and a leg!',
 'The exam was a piece of cake.',
 "You're going to be a piece of cake for you."]

Unnamed: 0,idioms,RLHF_model_results
0,"He expected a warm welcomem, but instead, he w...","He expected a warm welcomem, but instead, he w..."
1,"Despite his attempts to reconcile, she continu...",to
2,I was suprised at the party last night. Jessi ...,idiom
3,We didn't reach a conclusion about project yet...,You didn't reach a conclusion about project yet.
4,My mom will go bananas if I forgot to feed the...,My mom will go bananas if I forgot to feed the...
5,I'll end up going bananas if I have to work in...,You'll end up going bananas.
6,Things go south.,Things go south.
7,John's performance in the last quarter went so...,The performance in the last quarter went south.
8,Break a leg tonight.,You're going to break a leg tonight.
9,I am sure you can do it. Break a leg!,You can do it.
