In [1]:
# Autoreload: load IPython's autoreload extension in mode 2 so that edits to
# imported project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [65]:
from pathlib import Path

from tqdm.notebook import tqdm

import torch

import datasets
# Don't show progress datasets bars
datasets.disable_progress_bar()
from datasets import load_dataset

from fastchat.model import get_conversation_template

import sys
sys.path.insert(0, str(Path.cwd().parent.resolve()))
from model import get_model
from dataset import collator
from utils import get_tokenizer

In [3]:
# Pick the compute device: the currently selected CUDA device when CUDA is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    device = torch.device("cpu")
# Bare expression so the notebook displays the chosen device.
device

device(type='cuda', index=0)

In [4]:
# Model/tokenizer identifier: passed to get_tokenizer and get_model, and to
# get_conversation_template when the prompt is built.
tokenizer_name = "lmsys/vicuna-7b-v1.3"

# Dataset identifier passed to load_dataset.
dataset_name = "imdb"

In [5]:
# Tokenizer: loaded through the project helper get_tokenizer (imported from utils).
tokenizer = get_tokenizer(tokenizer_name)

Loading tokenizer lmsys/vicuna-7b-v1.3...
Loaded tokenizer.



## Dataset

In [57]:
# Dataset for PPO training: the "train" split of `dataset_name`.
dataset = load_dataset(dataset_name, split="train")
# Printing the dataset shows its feature names and number of rows.
print(dataset)

Found cached dataset imdb (/admin/home-augustas/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})


In [69]:
def get_sentiment(label):
    """Translate a class label into the capitalized gold answer string.

    Args:
        label: Class label; only its truthiness is inspected.

    Returns:
        "Positive" when ``label`` is truthy, "Negative" otherwise.
    """
    if label:
        return "Positive"
    return "Negative"

# Dataset template
def create_get_prompt_fn(tokenizer_name):
    """Build a prompt-formatting function bound to one conversation template.

    Args:
        tokenizer_name: Name passed to ``get_conversation_template`` to select
            the chat template.

    Returns:
        A function ``get_prompt(text, answer_prefix="Sentiment:")`` that
        returns the full prompt string asking for the sentiment of ``text``.
    """
    def get_prompt(text, answer_prefix="Sentiment:"):
        """Wrap a movie review in the sentiment-classification instruction.

        Args:
            text: The review to classify.
            answer_prefix: Label that introduces the answer. It is used in the
                format description, in the "do not print" instruction and as
                the last line of the user message.

        Returns:
            The string returned by the conversation template after appending
            the user message and an empty (``None``) reply for the second role.
        """
        # The template is requested on every call rather than once in the
        # enclosing scope.
        conv = get_conversation_template(tokenizer_name)

        message = (
            "Classify the movie review as either positive or negative.\n\n"
            "Desired format:\n"
            # Derived from answer_prefix so the format line cannot disagree
            # with the rest of the message when a custom prefix is passed.
            f"{answer_prefix} <identified_sentiment>\n"
            f"Do not print \"{answer_prefix}\" again, just the sentiment.\n\n"
            f"Movie review:\n```\n{text}\n```\n"
            f"{answer_prefix}"
        )

        # First role carries the instruction; the second role is left open
        # (None) so the prompt ends where the model is expected to answer.
        conv.append_message(conv.roles[0], message)
        conv.append_message(conv.roles[1], None)

        return conv.get_prompt()

    return get_prompt

# Bind the prompt builder to the configured template name and preview the
# prompt it produces for a placeholder review.
get_prompt = create_get_prompt_fn(tokenizer_name)
input_text = get_prompt("<movie review>")
print(input_text)

A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Classify the movie review as either positive or negative.

Desired format:
Sentiment: <identified_sentiment>
Do not print "Sentiment:" again, just the sentiment.

Movie review:
```
<movie review>
```
Sentiment: ASSISTANT:


## Model

In [8]:
# Model: loaded through the project helper get_model with load_in_8bit=True.
# The commented-out call below is the variant without the 8-bit flag.
# model = get_model(tokenizer_name, device)
model = get_model(tokenizer_name, device, load_in_8bit=True)

# Report the footprint of the wrapped pretrained model, scaled by 1024**3
# (presumably bytes -> GiB, printed as "GB" — confirm against the library docs).
memory_usage = model.pretrained_model.get_memory_footprint() / (1024 ** 3)
print(f"{memory_usage=:.2f} GB")

Loading policy model...

is_bf16_possible=False
kwargs={'load_in_8bit': True, 'torch_dtype': None}


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded subject model with 6,738,419,713 parameters.
Model dtype: torch.float16

memory_usage=6.58 GB


In [107]:
# Length of the longest tokenized prompt in the processed dataset.
# NOTE(review): `processed_dataset` is only created in a later cell, so this
# cell raises NameError under Restart & Run All unless it is moved below that cell.
prompt_max_len = max(
    len(row["input_ids"]) for row in processed_dataset
)
# Bare expression so the notebook displays the value.
prompt_max_len

1999

In [109]:
from collections import defaultdict

def apply_prompt(batch):
    """Build the prompt string for every review in a batch.

    Args:
        batch: Mapping with a "text" entry holding the raw reviews.

    Returns:
        A plain dict with a single "prompt" list aligned with
        ``batch["text"]``. The key is always present, so an empty batch yields
        ``{"prompt": []}`` rather than a mapping with no columns.
    """
    return {"prompt": [get_prompt(item) for item in batch["text"]]}

def filter_too_long(batch):
    """Flag which tokenized prompts are short enough to keep.

    Args:
        batch: Mapping with an "input_ids" entry, one token sequence per row.

    Returns:
        A list of booleans, one per row: True when the sequence is strictly
        shorter than ``tokenizer.model_max_length - 8``.
    """
    # Subtract a bit more than strictly needed so that generated tokens still
    # fit below the tokenizer's maximum length.
    length_limit = tokenizer.model_max_length - 8
    return [len(token_ids) < length_limit for token_ids in batch["input_ids"]]

# Step 1: add a "prompt" column by wrapping every review in the instruction
# template (batched, across 12 worker processes).
processed_dataset = dataset.map(
    lambda batch: apply_prompt(batch), batched=True, num_proc=12
)
# Step 2: tokenize the prompts; the tokenizer's outputs become new columns.
processed_dataset = processed_dataset.map(
    lambda batch: tokenizer(batch["prompt"]), batched=True, num_proc=12
)
# Step 3: drop rows whose prompt does not leave headroom below
# tokenizer.model_max_length (see filter_too_long).
processed_dataset = processed_dataset.filter(
    filter_too_long,
    batched=True
)

# Remove the "token_type_ids" column; none of the visible cells read it.
processed_dataset = processed_dataset.remove_columns(["token_type_ids"])

# Return "input_ids" and "attention_mask" as torch tensors while still exposing
# the remaining columns (output_all_columns=True). This mutates the dataset's
# format in place.
processed_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"], output_all_columns=True)
# Bare expression so the notebook displays the resulting features and row count.
processed_dataset

Loading cached processed dataset at /admin/home-augustas/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-707ac42cae4740a3_*_of_00012.arrow
Loading cached processed dataset at /admin/home-augustas/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-a04ea056e5dee5a2_*_of_00012.arrow
Loading cached processed dataset at /admin/home-augustas/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-682f7216ae3da56e.arrow


Dataset({
    features: ['text', 'label', 'prompt', 'input_ids', 'attention_mask'],
    num_rows: 24989
})

In [113]:
from trl import PPOTrainer, PPOConfig

# PPO configuration with no overrides passed to the constructor.
config = PPOConfig()

# No custom optimizer is supplied; None is forwarded to the trainer.
optimizer = None

# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
# (ref_model=None: no explicit reference model is provided here).
# NOTE(review): `dataset` is the raw dataset, not `processed_dataset`. The
# generation cells build their own DataLoader over `processed_dataset` and only
# call ppo_trainer.generate, so this may be harmless — confirm it is intended.
ppo_trainer = PPOTrainer(
    config,
    model,
    ref_model=None,
    tokenizer=tokenizer,
    dataset=dataset,
    data_collator=collator,
    optimizer=optimizer,
)

In [117]:
# Small evaluation subset: the first 32 processed examples, in order, in
# batches of 4, collated with the project's `collator`.
dataloader = torch.utils.data.DataLoader(
    processed_dataset.select(range(32)),
    batch_size=4, collate_fn=collator,
    num_workers=12, shuffle=False,
)
print(f"Dataloader length: {len(dataloader)}")

# Sampling is disabled (do_sample=False) and at most 4 new tokens are generated.
generation_kwargs = {
    "top_k": 0,
    "top_p": 1.0,
    "do_sample": False,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": 100_000, # NOTE(review): presumably an id outside the vocabulary so EOS never ends generation early — confirm
    "pad_to_multiple_of": 8, # TODO: double-check, but this seems to work and to be faster
    "max_new_tokens": 4,
}

# Gold answers and decoded model answers, collected in the same order.
gold_outputs = []
outputs = []
for batch in tqdm(dataloader, total=len(dataloader), leave=False):
    # Gold answer strings for this batch, derived from the dataset labels.
    batch_gold_outputs = [get_sentiment(label) for label in batch["label"]]
    gold_outputs.extend(batch_gold_outputs)

    question_tensors = batch["input_ids"]

    # return_prompt=False: only the generated continuation is returned.
    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        # length_sampler=output_length_sampler, # TODO: can be none
        batch_size=4, # TODO: generations are made in batches
        **generation_kwargs,
    )
    # Decode the generated ids to text, dropping special tokens.
    batch["response"] = tokenizer.batch_decode(
        response_tensors, skip_special_tokens=True, spaces_between_special_tokens=False
    )

    outputs.extend(batch["response"])

# Sanity check: both lists should hold one entry per evaluated example.
len(outputs), len(gold_outputs)

Dataloader length: 8


  0%|          | 0/8 [00:00<?, ?it/s]

(32, 32)

In [118]:
# Print every gold answer next to the model's answer for manual inspection,
# separated by a rule of 80 dashes.
for output, golden_output in zip(outputs, gold_outputs):
    print(f"{golden_output=}\n{output=}")
    print("-" * 80)

golden_output='Negative'
output='Positive'
--------------------------------------------------------------------------------
golden_output='Negative'
output='Negative'
--------------------------------------------------------------------------------
golden_output='Negative'
output='Negative'
--------------------------------------------------------------------------------
golden_output='Negative'
output='Negative'
--------------------------------------------------------------------------------
golden_output='Negative'
output='Negative'
--------------------------------------------------------------------------------
golden_output='Negative'
output='Negative'
--------------------------------------------------------------------------------
golden_output='Negative'
output='Negative'
--------------------------------------------------------------------------------
golden_output='Negative'
output='Negative'
--------------------------------------------------------------------------------
golden_o

In [111]:
# TODO: do I need to set the pad_token?
# Sampling is disabled (do_sample=False) and at most 4 new tokens are generated.
generation_kwargs = {
    "top_k": 0,
    "top_p": 1.0,
    "do_sample": False,
    # "pad_to_multiple_of": 8,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": 100_000, # why is this value like this?
    "max_new_tokens": 4,
}

# Only a handful of examples are inspected, so select them up front. The
# progress bar total then equals the number of iterations; len(dataset) would
# overstate it, because rows are filtered out of processed_dataset and only a
# few of them are visited.
num_examples = 5
preview_dataset = processed_dataset.select(
    range(min(num_examples, len(processed_dataset)))
)

for i, example in enumerate(
    tqdm(preview_dataset, total=len(preview_dataset), leave=False)
):
    label = get_sentiment(example["label"])
    print(f"{label=}")

    # Add a leading batch dimension and move the tensors to the model's device.
    inputs = {
        "input_ids": example["input_ids"].unsqueeze(0).to(device),
        "attention_mask": example["attention_mask"].unsqueeze(0).to(device),
    }

    output_ids = model.generate(
        **inputs,
        **generation_kwargs,
    )
    # Keep only the newly generated tokens by slicing off the prompt prefix.
    output_ids = output_ids[0][len(inputs["input_ids"][0]):]
    outputs = tokenizer.decode(
        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
    )
    print(f"{outputs=}")
    print("-" * 100)

  0%|          | 0/25000 [00:00<?, ?it/s]

label='Negative'
outputs='Positive'
----------------------------------------------------------------------------------------------------
label='Negative'
outputs='Negative'
----------------------------------------------------------------------------------------------------
label='Negative'
outputs='Negative'
----------------------------------------------------------------------------------------------------
label='Negative'
outputs='Negative'
----------------------------------------------------------------------------------------------------
label='Negative'
outputs='Negative'
----------------------------------------------------------------------------------------------------


In [64]:
# `conv` otherwise only exists as a local variable inside get_prompt, so build
# the conversation template here to make its role names available to this cell.
conv = get_conversation_template(tokenizer_name)

# Show `input_text` and `outputs` labelled with the template's two roles
# (roles[0] is the role that carries the instruction, roles[1] the replying role).
print(f"{conv.roles[0]}: {input_text}")
print(f"{conv.roles[1]}: {outputs}")

USER: Classify the movie review as either positive or negative.

Desired format:
Sentiment: <identified_sentiment>
Do not print "Sentiment:" again, just the sentiment.

Movie review:
```
Oh, brother...after hearing about this ridiculous film for umpteen years all I can think of is that old Peggy Lee song..<br /><br />"Is that all there is??" ...I was just an early teen when this smoked fish hit the U.S. I was too young to get in the theater (although I did manage to sneak into "Goodbye Columbus"). Then a screening at a local film museum beckoned - Finally I could see this film, except now I was as old as my parents were when they schlepped to see it!!<br /><br />The ONLY reason this film was not condemned to the anonymous sands of time was because of the obscenity case sparked by its U.S. release. MILLIONS of people flocked to this stinker, thinking they were going to see a sex film...Instead, they got lots of closeups of gnarly, repulsive Swedes, on-street interviews in bland shopping