In [10]:
# Autoreload: re-import edited modules (e.g. the local model / dataset / utils
# helpers imported in the next cell) before each cell executes, so changes to
# them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
from collections import defaultdict

from pathlib import Path

from tqdm.notebook import tqdm

import torch

import datasets
# Don't show progress datasets bars
datasets.disable_progress_bar()
from datasets import load_dataset

from fastchat.model import get_conversation_template

import sys
sys.path.insert(0, str(Path.cwd().parent.resolve()))
from model import get_model
from dataset import collator
from utils import get_tokenizer

In [12]:
# Use the currently selected CUDA device when a GPU is present, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [13]:
# Hub id used for the tokenizer, the model and the fastchat conversation template.
tokenizer_name = "lmsys/vicuna-7b-v1.3"

# Hugging Face dataset (path, config name): SuperGLUE's RTE subset.
dataset_path, dataset_name = "super_glue", "rte"

In [14]:
# Tokenizer from the project helper `utils.get_tokenizer`; it is reused below for
# tokenizing prompts, the max-length filter, the pad token id and decoding.
tokenizer = get_tokenizer(tokenizer_name)

Loading tokenizer lmsys/vicuna-7b-v1.3...
Loaded tokenizer.



## Dataset

In [15]:
# Raw training split (intended for PPO training); prompts are rendered from it below.
dataset = load_dataset(path=dataset_path, name=dataset_name, split="train")
print(dataset)

Found cached dataset super_glue (/admin/home-augustas/.cache/huggingface/datasets/super_glue/rte/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed)


Dataset({
    features: ['premise', 'hypothesis', 'idx', 'label'],
    num_rows: 2490
})


In [16]:
# def doc_to_text(doc):
#     return "{}\nQuestion: {} True or False?\nAnswer:".format(
#         doc["premise"],
#         doc["hypothesis"],
#     )


def format_label(label):
    """Convert an RTE label id into the answer string the model should produce.

    A falsy label (0) maps to "True" and any truthy label maps to "False".
    NOTE(review): this assumes the super_glue/rte convention 0 = entailment,
    1 = not_entailment — confirm against the dataset card.
    """
    if label:
        return "False"
    return "True"

# Prompt template for RTE premise/hypothesis pairs
def create_get_prompt_fn(tokenizer_name):
    """Return a prompt builder bound to ``tokenizer_name``'s fastchat chat template.

    The returned ``get_prompt(example, answer_prefix="Answer:")`` takes a mapping
    with ``"premise"`` and ``"hypothesis"`` keys. It asks the model whether the
    statement (hypothesis) is correct given the premise, and posts that request
    as the user turn of ``get_conversation_template(tokenizer_name)``. The
    assistant turn is left empty, and the rendered prompt string is returned.
    ``answer_prefix`` ends the user message and is also quoted in the format
    instructions.
    """

    def get_prompt(example, answer_prefix="Answer:"):
        # The template is fetched on every call, so each prompt is built on the
        # object returned here (presumably a fresh copy — confirm in fastchat).
        conversation = get_conversation_template(tokenizer_name)

        instruction_parts = [
            "Consider the premise below in triple backticks ",
            "and the corresponding statement. Based on the information in the premise, ",
            "tell whether the statement is correct. Your answer should be either true or false.\n\n",
            "Desired format:\n",
            "Answer: <your_answer>\n",
            f"Do not print \"{answer_prefix}\" again, just what you think the answer is.\n\n",
            f"Premise:\n```\n{example['premise']}\n```\n",
            f"Statement: {example['hypothesis']}\n",
            f"{answer_prefix}",
        ]
        user_message = "".join(instruction_parts)

        user_role = conversation.roles[0]
        assistant_role = conversation.roles[1]
        conversation.append_message(user_role, user_message)
        # An empty (None) assistant message leaves the prompt ending on the
        # assistant's turn, so the model's generation is the answer.
        conversation.append_message(assistant_role, None)

        return conversation.get_prompt()

    return get_prompt

get_prompt = create_get_prompt_fn(tokenizer_name)

# Preview one fully rendered prompt (the training example at index 2).
input_text = get_prompt(example=dataset[2])
print(input_text)

A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Consider the premise below in triple backticks and the corresponding statement. Based on the information in the premise, tell whether the statement is correct. Your answer should be either true or false.

Desired format:
Answer: <your_answer>
Do not print "Answer:" again, just what you think the answer is.

Premise:
```
Herceptin was already approved to treat the sickest breast cancer patients, and the company said, Monday, it will discuss with federal regulators the possibility of prescribing the drug for more breast cancer patients.
```
Statement: Herceptin can be used to treat breast cancer.
Answer: ASSISTANT:


## Model

In [17]:
# Policy model, loaded with 8-bit weights. Non-quantized alternative:
#   model = get_model(tokenizer_name, device)
model = get_model(
    tokenizer_name,
    device,
    load_in_8bit=True,
    low_cpu_mem_usage=True,
)

# get_memory_footprint() reports bytes; convert to GiB (2**30 == 1024**3).
memory_usage = model.pretrained_model.get_memory_footprint() / 2**30
print(f"{memory_usage=:.2f} GB")

Loading policy model...

is_bf16_possible=False
kwargs={'load_in_8bit': True, 'low_cpu_mem_usage': True, 'torch_dtype': None}


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded subject model with 6,738,419,713 parameters.
Model dtype: torch.float16

memory_usage=6.58 GB


In [26]:
def apply_prompt(batch):
    """Render the chat prompt for every (premise, hypothesis) row of a batch.

    Intended for ``Dataset.map(..., batched=True)``: ``batch`` maps column names
    to equal-length lists and must contain ``"premise"`` and ``"hypothesis"``.

    Returns:
        A plain ``{"prompt": [...]}`` dict with one rendered prompt per row, in
        row order. The ``"prompt"`` key is present even for an empty batch, so
        the output schema does not depend on batch contents.
    """
    prompts = [
        get_prompt({"premise": premise, "hypothesis": hypothesis})
        for premise, hypothesis in zip(batch["premise"], batch["hypothesis"])
    ]
    return {"prompt": prompts}

def filter_too_long(batch, margin=8):
    """Keep-mask for ``Dataset.filter(..., batched=True)`` that drops over-long prompts.

    Args:
        batch: mapping with an ``"input_ids"`` column (one token-id sequence per row).
        margin: positions reserved below ``tokenizer.model_max_length`` so that
            generated tokens still fit in the context window. Defaults to 8,
            the value that used to be hard-coded.

    Returns:
        One bool per row: ``len(input_ids) < tokenizer.model_max_length - margin``.
    """
    # Subtract a bit more than the prompt needs to allow generations to be processed.
    max_prompt_length = tokenizer.model_max_length - margin
    return [len(input_ids) < max_prompt_length for input_ids in batch["input_ids"]]

# Render prompts, tokenize them, then drop rows whose prompt would not leave
# room for generated tokens.
processed_dataset = dataset.map(apply_prompt, batched=True, num_proc=12)
processed_dataset = processed_dataset.map(
    lambda batch: tokenizer(batch["prompt"]), batched=True, num_proc=12
)
processed_dataset = processed_dataset.filter(filter_too_long, batched=True, num_proc=12)

# Remove columns the model does not consume. Not every tokenizer emits
# token_type_ids, and `remove_columns` raises on a missing column, so only
# drop it when it is actually there.
if "token_type_ids" in processed_dataset.column_names:
    processed_dataset = processed_dataset.remove_columns(["token_type_ids"])

# Return input_ids / attention_mask as torch tensors and keep all other columns.
processed_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask"], output_all_columns=True
)
processed_dataset

Dataset({
    features: ['premise', 'hypothesis', 'idx', 'label', 'prompt', 'input_ids', 'attention_mask'],
    num_rows: 2490
})

In [19]:
# Push the processed dataset to hub
# processed_dataset.push_to_hub("AugustasM/imdb_vicuna", private=True)

In [27]:
from trl import PPOTrainer, PPOConfig

config = PPOConfig()

# No custom optimizer: PPOTrainer presumably builds its default one from `config` — confirm.
optimizer = None

# Build the PPOTrainer from the policy model and tokenizer. No explicit
# reference model is passed (ref_model=None).
# Train on the tokenized prompts: the raw `dataset` only has
# premise/hypothesis/idx/label columns and no input_ids, so it cannot supply
# PPO queries.
ppo_trainer = PPOTrainer(
    config,
    model,
    ref_model=None,
    tokenizer=tokenizer,
    dataset=processed_dataset,
    data_collator=collator,
    optimizer=optimizer,
)

In [32]:
# Iterate the tokenized prompts in their original order so that predictions
# stay aligned with the gold labels. For a quick run, swap in e.g.
# processed_dataset.shuffle(seed=42).select(range(8)).
dataloader = torch.utils.data.DataLoader(
    processed_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=12,
    collate_fn=collator,
)
print(f"Dataloader length: {len(dataloader)}")

# Greedy decoding of exactly one new token: the expected answer is "True" or "False".
generation_kwargs = dict(
    top_k=0,
    top_p=1.0,
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id,
    # NOTE(review): presumably an id outside the vocabulary, so generation is never
    # stopped by an EOS token and only `max_new_tokens` bounds it — confirm.
    eos_token_id=100_000,
    # TODO: double-check; this appears to work and to be faster.
    pad_to_multiple_of=8,
    max_new_tokens=1,
)

gold_outputs, outputs = [], []
for batch in tqdm(dataloader, total=len(dataloader), leave=False):
    batch_gold_outputs = list(map(format_label, batch["label"]))
    gold_outputs += batch_gold_outputs

    question_tensors = batch["input_ids"]
    # length_sampler is left unset; the trainer generates in sub-batches of `batch_size`.
    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        batch_size=8,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(
        response_tensors,
        skip_special_tokens=True,
        spaces_between_special_tokens=False,
    )
    outputs += batch["response"]

len(outputs), len(gold_outputs)

Dataloader length: 39


  0%|          | 0/39 [00:00<?, ?it/s]

(2490, 2490)

In [33]:
# for output, golden_output in zip(outputs, gold_outputs):
#     print(f"{golden_output=}\n{output=}")
#     print("-" * 80)

In [34]:
from collections import Counter

# The two counts should match (one prediction per gold label); then show how
# often each predicted string occurs.
print(f"{len(gold_outputs)} {len(outputs)}")
print(Counter(outputs))

2490 2490
Counter({'False': 1380, 'True': 1109, 'strik': 1})


In [35]:
def get_accuracy(gold_outputs, outputs):
    """Fraction of positions where the prediction exactly equals the gold answer.

    Args:
        gold_outputs: expected answers (e.g. "True"/"False").
        outputs: model predictions, aligned index-by-index with ``gold_outputs``.

    Returns:
        Exact-match accuracy as a float in [0, 1].

    Raises:
        ValueError: if the two sequences differ in length (``zip`` would
            otherwise silently ignore the extra items), or if they are empty
            (accuracy is undefined).
    """
    if len(gold_outputs) != len(outputs):
        raise ValueError(
            f"gold_outputs and outputs differ in length: {len(gold_outputs)} != {len(outputs)}"
        )
    if not outputs:
        raise ValueError("cannot compute accuracy of zero predictions")
    num_correct = sum(1 for gold, output in zip(gold_outputs, outputs) if gold == output)
    return num_correct / len(outputs)

# Exact-match accuracy of the greedy one-token answers against the gold labels.
get_accuracy(gold_outputs, outputs)

0.7514056224899598

In [36]:
# for output, golden_output in zip(outputs, gold_outputs):
#     if output != golden_output:
#         print(f"{golden_output=}\n{output=}")
#         print("-" * 80)

: 