In [1]:
# Autoreload
%load_ext autoreload
%autoreload 2

In [2]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = "2"

In [3]:
from collections import defaultdict

from pathlib import Path

from tqdm.notebook import tqdm

import torch

import datasets
# Don't show progress datasets bars
datasets.disable_progress_bar()
from datasets import load_dataset

import sys
sys.path.insert(0, str(Path.cwd().parent.resolve()))
from model import get_model
from dataset import collator
from utils import get_tokenizer

In [4]:
device = torch.device(
    f"cuda:0" if torch.cuda.is_available() else "cpu"
)
device

device(type='cuda', index=0)

In [5]:
tokenizer_name = "meta-llama/Llama-2-7b-chat-hf"

dataset_path = "glue"
dataset_name = "qnli"

In [6]:
# Tokenizer
tokenizer = get_tokenizer(tokenizer_name)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.model_max_length = 4096
tokenizer

Loading tokenizer meta-llama/Llama-2-7b-chat-hf...


Using pad_token, but it is not set yet.


Loaded tokenizer.



LlamaTokenizerFast(name_or_path='meta-llama/Llama-2-7b-chat-hf', vocab_size=32000, model_max_length=4096, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'pad_token': '<unk>'}, clean_up_tokenization_spaces=False)

## Dataset

In [7]:
# Dataset for PPO training
dataset = load_dataset(dataset_path, dataset_name, split="train")
dataset = dataset.shuffle(seed=42).select(range(2048))
print(dataset)

Found cached dataset glue (/admin/home-augustas/.cache/huggingface/datasets/glue/qnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Loading cached shuffled indices for dataset at /admin/home-augustas/.cache/huggingface/datasets/glue/qnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-60510de4495de1ff.arrow


Dataset({
    features: ['question', 'sentence', 'label', 'idx'],
    num_rows: 2048
})


In [76]:
def format_label(label):
    return "No" if label else "Yes"

# Dataset template
def get_prompt(example):
    message = (
            "Consider the sentence below in triple backticks "
            "and the corresponding question. Does the sentence contain enough information "
            "to answer the question? Your answer should be either yes or no.\n\n"
            "Desired format:\n"
            "Answer: <your_answer>\n"
            f"Do not print \"Answer:\" again, just what you think the answer is. Also, do not print a gap before your answer.\n\n"
            f"Sentence:\n```\n{example['sentence']}\n```\n"
            f"Question: {example['question']}\n"
            "REMEMBER, DO NOT print \"Answer:\" again.\n"
            f"Answer:"
        )

    prompt = (
        "<s>[INST] <<SYS>>\n"
        "A chat between a curious user and an artificial intelligence assistant. The assistant gives yes or no answers to the user's questions. "
        "The assistant outputs only the yes/no answer and does not repeat the \"Answer:\" instruction.\n"
        "<</SYS>>\n\n"
        f"{message} [/INST]"
    )

    return prompt


input_text = get_prompt(dataset[0])
print(f"'{input_text}'")

'<s>[INST] <<SYS>>
A chat between a curious user and an artificial intelligence assistant. The assistant gives yes or no answers to the user's questions. The assistant outputs only the yes/no answer and does not repeat the "Answer:" instruction.
<</SYS>>

Consider the sentence below in triple backticks and the corresponding question. Does the sentence contain enough information to answer the question? Your answer should be either yes or no.

Desired format:
Answer: <your_answer>
Do not print "Answer:" again, just what you think the answer is. Also, do not print a gap before your answer.

Sentence:
```
Apparently the sailor did not connect with the soldier, as Mahan believed he was innovating the term Middle East.
```
Question: Who did not connect with the soldier?
REMEMBER, DO NOT print "Answer:" again.
Answer: [/INST]'


## Model

In [8]:
# Model
# model = get_model(tokenizer_name, device)
model = get_model(tokenizer_name, device, load_in_8bit=True, low_cpu_mem_usage=True)

memory_usage = model.pretrained_model.get_memory_footprint() / (1024 ** 3)
print(f"{memory_usage=:.2f} GB")

Loading policy model...



Downloading (…)lve/main/config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

is_bf16_possible=False
kwargs={'load_in_8bit': True, 'low_cpu_mem_usage': True, 'torch_dtype': None}


Downloading (…)fetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Downloading (…)of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Loaded subject model with 6,738,419,713 parameters.
Model dtype: torch.float16

memory_usage=6.58 GB


In [77]:
def apply_prompt(batch):
    processed_batch = defaultdict(list)
    for sentence, question in zip(batch["sentence"], batch["question"]):
        example = { "sentence": sentence, "question": question }
        processed_batch["prompt"].append(get_prompt(example))
    
    return processed_batch

def filter_too_long(batch):
    # Substract a bit more to allow for generations to be processed
    return [len(x) < tokenizer.model_max_length - 8 for x in batch["input_ids"]]

processed_dataset = dataset.map(
    lambda batch: apply_prompt(batch), batched=True, num_proc=12
)
processed_dataset = processed_dataset.map(
    lambda batch: tokenizer(batch["prompt"]), batched=True, num_proc=12
)
processed_dataset = processed_dataset.filter(filter_too_long, batched=True, num_proc=12)

# Remove some columns
# processed_dataset = processed_dataset.remove_columns(["token_type_ids"])

processed_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"], output_all_columns=True)
processed_dataset

Dataset({
    features: ['question', 'sentence', 'label', 'idx', 'prompt', 'input_ids', 'attention_mask'],
    num_rows: 2048
})

In [78]:
print(processed_dataset[0]["prompt"])

<s>[INST] <<SYS>>
A chat between a curious user and an artificial intelligence assistant. The assistant gives yes or no answers to the user's questions. The assistant outputs only the yes/no answer and does not repeat the "Answer:" instruction.
<</SYS>>

Consider the sentence below in triple backticks and the corresponding question. Does the sentence contain enough information to answer the question? Your answer should be either yes or no.

Desired format:
Answer: <your_answer>
Do not print "Answer:" again, just what you think the answer is. Also, do not print a gap before your answer.

Sentence:
```
Apparently the sailor did not connect with the soldier, as Mahan believed he was innovating the term Middle East.
```
Question: Who did not connect with the soldier?
REMEMBER, DO NOT print "Answer:" again.
Answer: [/INST]


In [None]:
# Push the processed dataset to hub
# processed_dataset.push_to_hub("AugustasM/imdb_vicuna", private=True)

In [31]:
from trl import PPOTrainer, PPOConfig

config = PPOConfig()

optimizer = None

# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
ppo_trainer = PPOTrainer(
    config,
    model,
    ref_model=None,
    tokenizer=tokenizer,
    dataset=dataset,
    data_collator=collator,
    optimizer=optimizer,
)

In [32]:
def postprocess_fn(text_batch):
    outputs = []

    for text in text_batch:
        if "Yes" in text:
            outputs.append("Yes")
        elif "No" in text:
            outputs.append("No")
        else:
            outputs.append(text)

    return outputs

In [84]:
dataloader = torch.utils.data.DataLoader(
    # processed_dataset,
    # processed_dataset.select(range(8)),
    # processed_dataset.select(range(32)),
    processed_dataset.select(range(128)),
    # processed_dataset.select(range(1024)),
    batch_size=32, collate_fn=collator,
    num_workers=12, shuffle=False,
)
print(f"Dataloader length: {len(dataloader)}")

generation_kwargs = {
    "top_k": 0,
    "top_p": 1.0,
    # "do_sample": False,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": 100_000, # why is this value like this?
    "pad_to_multiple_of": 8, # TODO: double-check, but this seems to work and to be faster
    "max_new_tokens": 4,
}

gold_outputs = []
outputs = []
for batch in tqdm(dataloader, total=len(dataloader), leave=False):
    batch_gold_outputs = [format_label(label) for label in batch["label"]]
    gold_outputs.extend(batch_gold_outputs)

    question_tensors = batch["input_ids"]

    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        # length_sampler=output_length_sampler, # TODO: can be none
        batch_size=8, # TODO: generations are made in batches
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(
        response_tensors, skip_special_tokens=True, spaces_between_special_tokens=False
    )

    # Postprocess
    # batch["response"] = postprocess_fn(batch["response"])

    outputs.extend(batch["response"])

len(outputs), len(gold_outputs)

Dataloader length: 4


  0%|          | 0/4 [00:00<?, ?it/s]

(128, 128)

In [80]:
# for output, golden_output in zip(outputs, gold_outputs):
#     print(f"{golden_output=}\n{output=}")
#     print("-" * 80)

In [85]:
from collections import Counter

print(len(gold_outputs), len(outputs))
print(Counter([x[9:] for x in outputs]))

128 128
Counter({'Yes': 119, '': 7, 're': 1, 'No': 1})


In [81]:
from collections import Counter

print(len(gold_outputs), len(outputs))
print(Counter(outputs))

32 32
Counter({' Answer: Yes': 28, ' Sure, I': 4})


In [82]:
print(Counter([x.lower() for x in outputs]))

Counter({' answer: yes': 28, ' sure, i': 4})


In [41]:
def get_accuracy(gold_outputs, outputs):
    return sum([1 if gold == output else 0 for gold, output in zip(gold_outputs, outputs)]) / len(outputs)

get_accuracy(gold_outputs, outputs)

0.490234375

In [42]:
for output, golden_output in zip(outputs, gold_outputs):
    if output != golden_output:
        print(f"{golden_output=}\n{output=}")
        print("-" * 80)

golden_output='Yes'
output='Mah'
--------------------------------------------------------------------------------
golden_output='No'
output='Yes'
--------------------------------------------------------------------------------
golden_output='No'
output='Yes'
--------------------------------------------------------------------------------
golden_output='No'
output='Yes'
--------------------------------------------------------------------------------
golden_output='Yes'
output='No'
--------------------------------------------------------------------------------
golden_output='No'
output='Yes'
--------------------------------------------------------------------------------
golden_output='No'
output='Yes'
--------------------------------------------------------------------------------
golden_output='Yes'
output='yes'
--------------------------------------------------------------------------------
golden_output='Yes'
output='No'
--------------------------------------------------------------