In [1]:
# Autoreload
%load_ext autoreload
%autoreload 2

In [10]:
from collections import defaultdict

from pathlib import Path

from tqdm.notebook import tqdm

import torch

import datasets
# Don't show the `datasets` library's progress bars
datasets.disable_progress_bar()
from datasets import load_dataset

from fastchat.model import get_conversation_template

import sys
sys.path.insert(0, str(Path.cwd().parent.resolve()))
from model import get_model
from dataset import collator
from utils import get_tokenizer

In [3]:
# Prefer the currently selected CUDA device; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [4]:
# Hugging Face Hub identifiers: the chat model whose tokenizer / chat template is
# used throughout, and the sentiment dataset to classify.
tokenizer_name, dataset_name = "lmsys/vicuna-7b-v1.3", "imdb"

In [5]:
# Tokenizer for the chat model, loaded through the project's `utils.get_tokenizer`
# helper (which logs its progress, as seen in the output below).
tokenizer = get_tokenizer(tokenizer_name)

Loading tokenizer lmsys/vicuna-7b-v1.3...


Loaded tokenizer.



## Dataset

In [6]:
# Raw IMDB training split (`text` + `label` columns); the prompt dataset further
# down is derived from it.
dataset = load_dataset(path=dataset_name, split="train")
print(dataset)

Found cached dataset imdb (/admin/home-augustas/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})


In [7]:
def get_sentiment(label):
    """Map an IMDB label to the answer string the model is expected to produce.

    Truthy labels (1) become "Positive"; falsy labels (0) become "Negative".
    """
    if label:
        return "Positive"
    return "Negative"

def create_get_prompt_fn(tokenizer_name):
    """Build a formatter that wraps a movie review in the model's chat template.

    Args:
        tokenizer_name: Model identifier handed to fastchat's
            ``get_conversation_template`` to pick the conversation format.

    Returns:
        ``get_prompt(text, answer_prefix="Sentiment:") -> str``. The returned
        prompt ends with an empty assistant turn, so the model's continuation
        is the predicted sentiment.
    """

    def get_prompt(text, answer_prefix="Sentiment:"):
        # Build a fresh conversation on every call so messages from earlier
        # prompts never accumulate.
        conv = get_conversation_template(tokenizer_name)

        message = (
            "Classify the movie review as either positive or negative.\n\n"
            "Desired format:\n"
            # Use answer_prefix here as well, so the advertised format always
            # matches the prefix the prompt ends with and tells the model not
            # to repeat.
            f"{answer_prefix} <identified_sentiment>\n"
            f"Do not print \"{answer_prefix}\" again, just the sentiment.\n\n"
            f"Movie review:\n```\n{text}\n```\n"
            f"{answer_prefix}"
        )

        conv.append_message(conv.roles[0], message)
        # A None assistant message leaves an open assistant turn at the end of
        # the prompt for the model to complete.
        conv.append_message(conv.roles[1], None)

        return conv.get_prompt()

    return get_prompt

# Instantiate the formatter for the chosen model's chat template and preview it
# on a placeholder review.
get_prompt = create_get_prompt_fn(tokenizer_name)
input_text = get_prompt(text="<movie review>")
print(input_text)

A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Classify the movie review as either positive or negative.

Desired format:
Sentiment: <identified_sentiment>
Do not print "Sentiment:" again, just the sentiment.

Movie review:
```
<movie review>
```
Sentiment: ASSISTANT:


## Model

In [9]:
# Policy model, loaded in 8-bit and with low CPU memory usage during loading to
# reduce its memory requirements.
model = get_model(
    tokenizer_name,
    device,
    load_in_8bit=True,
    low_cpu_mem_usage=True,
)

# Memory footprint of the wrapped pretrained model, converted from bytes to GiB.
bytes_per_gib = 2 ** 30
memory_usage = model.pretrained_model.get_memory_footprint() / bytes_per_gib
print(f"{memory_usage=:.2f} GB")

Loading policy model...

is_bf16_possible=False
kwargs={'load_in_8bit': True, 'low_cpu_mem_usage': True, 'torch_dtype': None}


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded subject model with 6,738,419,713 parameters.
Model dtype: torch.float16

memory_usage=6.58 GB


In [21]:

def apply_prompt(batch):
    """Add a ``"prompt"`` column by wrapping each review in ``batch["text"]``.

    Intended for ``Dataset.map(..., batched=True)``, where ``batch`` maps column
    names to lists. The ``"prompt"`` key is always returned, even for an empty
    batch, so the output schema is stable.
    """
    return {"prompt": [get_prompt(text) for text in batch["text"]]}


def filter_too_long(batch, margin=8):
    """Return a keep-mask selecting prompts that leave room for generated tokens.

    Args:
        batch: Batched dataset slice with an ``"input_ids"`` column.
        margin: Number of tokens to keep free below ``tokenizer.model_max_length``
            so that generations can still be processed.

    Returns:
        List of bools, True where ``len(input_ids) < model_max_length - margin``.
    """
    # The limit is the same for every row, so compute it once per batch.
    max_prompt_length = tokenizer.model_max_length - margin
    return [len(input_ids) < max_prompt_length for input_ids in batch["input_ids"]]


# Worker processes for the datasets map/filter calls below.
NUM_PROC = 12

processed_dataset = dataset.map(apply_prompt, batched=True, num_proc=NUM_PROC)
processed_dataset = processed_dataset.map(
    lambda batch: tokenizer(batch["prompt"]), batched=True, num_proc=NUM_PROC
)
processed_dataset = processed_dataset.filter(filter_too_long, batched=True, num_proc=NUM_PROC)

# Drop tokenizer outputs the model does not use. Not every tokenizer emits
# token_type_ids, and remove_columns raises on unknown column names, so only
# remove columns that actually exist.
unused_columns = [
    column for column in ("token_type_ids",) if column in processed_dataset.column_names
]
if unused_columns:
    processed_dataset = processed_dataset.remove_columns(unused_columns)

processed_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask"], output_all_columns=True
)
processed_dataset

Loading cached processed dataset at /admin/home-augustas/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-707ac42cae4740a3_*_of_00012.arrow


Dataset({
    features: ['text', 'label', 'prompt', 'input_ids', 'attention_mask'],
    num_rows: 24989
})

In [23]:
# Optional: upload the processed dataset to the Hugging Face Hub as a private
# dataset. Left disabled; uncomment to publish.
# processed_dataset.push_to_hub("AugustasM/imdb_vicuna", private=True)

In [12]:
from trl import PPOTrainer, PPOConfig

# Default PPO hyperparameters. In the cells below the trainer is only used for
# its batched `generate` helper.
config = PPOConfig()

# None lets trl construct its default optimizer.
# NOTE(review): confirm which optimizer the installed trl version picks.
optimizer = None

ppo_trainer = PPOTrainer(
    config,
    model,
    # NOTE(review): with ref_model=None trl presumably builds the reference
    # model itself — confirm for the installed version.
    ref_model=None,
    tokenizer=tokenizer,
    # Use the tokenized/filtered dataset so the trainer's own dataloader yields
    # input_ids. The raw IMDB split only has `text` and `label` columns.
    dataset=processed_dataset,
    data_collator=collator,
    optimizer=optimizer,
)

In [15]:
# Evaluate on a fixed random subset of the processed training split.
NUM_EVAL_EXAMPLES = 1024
EVAL_SEED = 42

eval_subset = processed_dataset.shuffle(seed=EVAL_SEED).select(range(NUM_EVAL_EXAMPLES))
dataloader = torch.utils.data.DataLoader(
    eval_subset,
    batch_size=64,
    collate_fn=collator,
    num_workers=12,
    shuffle=False,
)
print(f"Dataloader length: {len(dataloader)}")

generation_kwargs = {
    # Greedy decoding. top_k / top_p only affect sampling, so they are inert here.
    "top_k": 0,
    "top_p": 1.0,
    "do_sample": False,
    "pad_token_id": tokenizer.pad_token_id,
    # NOTE(review): presumably an id outside the vocabulary, so EOS never stops
    # generation early and every response runs max_new_tokens steps — confirm intent.
    "eos_token_id": 100_000,
    "pad_to_multiple_of": 8,  # TODO: double-check, but this seems to work and to be faster
    "max_new_tokens": 4,
}

gold_outputs = []
outputs = []
for batch in tqdm(dataloader, leave=False):
    gold_outputs.extend(get_sentiment(label) for label in batch["label"])

    response_tensors = ppo_trainer.generate(
        batch["input_ids"],
        return_prompt=False,  # keep only the newly generated tokens
        batch_size=8,  # trl runs generation in sub-batches of this size
        **generation_kwargs,
    )
    responses = tokenizer.batch_decode(
        response_tensors, skip_special_tokens=True, spaces_between_special_tokens=False
    )
    # Strip surrounding whitespace so the exact-match comparison against
    # "Positive"/"Negative" is not thrown off by formatting-only differences.
    outputs.extend(response.strip() for response in responses)

len(outputs), len(gold_outputs)

Dataloader length: 16


  0%|          | 0/16 [00:00<?, ?it/s]

(1024, 1024)

In [16]:
from collections import Counter

# Both lists should hold one entry per evaluated example.
print(len(gold_outputs), len(outputs))
# Distribution of predicted labels; this also reveals answers outside Positive/Negative.
prediction_counts = Counter(outputs)
print(prediction_counts)

1024 1024
Counter({'Negative': 531, 'Positive': 492, 'Neutral': 1})


In [19]:
def get_accuracy(gold_outputs, outputs):
    """Return the fraction of predictions that exactly match their gold labels.

    Args:
        gold_outputs: Sequence of reference labels (e.g. "Positive"/"Negative").
        outputs: Sequence of predicted labels, aligned index-by-index with
            ``gold_outputs``.

    Returns:
        Exact-match accuracy as a float in [0, 1].

    Raises:
        ValueError: If the sequences differ in length (``zip`` would otherwise
            silently drop the tail and skew the score), or if they are empty.
    """
    if len(gold_outputs) != len(outputs):
        raise ValueError(
            f"Length mismatch: {len(gold_outputs)} gold labels vs {len(outputs)} predictions"
        )
    if len(outputs) == 0:
        raise ValueError("Cannot compute accuracy of an empty set of predictions")
    num_correct = sum(gold == output for gold, output in zip(gold_outputs, outputs))
    return num_correct / len(outputs)

# Exact-match accuracy of the model's answers on the evaluation subset.
get_accuracy(gold_outputs=gold_outputs, outputs=outputs)

0.93359375

In [20]:
# Show every example where the model's answer differs from the gold label.
for output, golden_output in zip(outputs, gold_outputs):
    if output == golden_output:
        continue
    print(f"{golden_output=}\n{output=}", "-" * 80, sep="\n")

golden_output='Negative'
output='Positive'
--------------------------------------------------------------------------------
golden_output='Positive'
output='Negative'
--------------------------------------------------------------------------------
golden_output='Negative'
output='Positive'
--------------------------------------------------------------------------------
golden_output='Positive'
output='Negative'
--------------------------------------------------------------------------------
golden_output='Positive'
output='Negative'
--------------------------------------------------------------------------------
golden_output='Positive'
output='Negative'
--------------------------------------------------------------------------------
golden_output='Positive'
output='Negative'
--------------------------------------------------------------------------------
golden_output='Negative'
output='Positive'
--------------------------------------------------------------------------------
golden_o