In [1]:
from trl import PPOConfig, PPOTrainer
import utils
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModel,
    Trainer,
    TrainingArguments,
    BertModel,
    pipeline,
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
)
import yaml
import getpass
import wandb
from typing import Dict, Any
import torch as t
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
from tqdm import tqdm
import trl
import importlib

device = t.device("cuda" if t.cuda.is_available() else "cpu")

In [2]:
# RUN THIS BLOCK IF YOU CHANGE UTILS BUT DON'T WANT TO RERUN WHOLE NOTEBOOK
!nvidia-smi
# importlib.reload(utils)

Sun May 12 16:00:18 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100X                   On  | 00000000:C6:00.0 Off |                    0 |
| N/A   56C    P0              70W / 300W |     21MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
def reward_fn(
    model: AutoModel,
    reward_tokenizer: AutoTokenizer,
    prompt_text: list[str],
    response_text: list[str],
    device: str,
) -> list[t.FloatTensor]:
    """Compute the reward for a given response to a prompt.

    Args:
        model (AutoModel): Huggingface model.
        tokenizer (AutoTokenizer): Huggingface tokenizer.
        prompt_text (list[str]): List of strings representing the prompt.
        response_text (list[str]): List of strings representing the response.
        device (str, optional): Device to run the model on. Defaults to 'cpu'.

    Returns:
        list[float]: A list of floats representing the reward.

    """
    with t.no_grad():
        encoding = reward_tokenizer(
            prompt_text,
            response_text,
            truncation=True,
            max_length=512,
            padding='max_length',
            return_tensors='pt',
        )
        encoding = encoding.to(device)

        logits = model(**encoding).logits
        # scores = logits.cpu().numpy().flatten().tolist()

        return logits

def setup_logging(hps: Dict[str, Any], log_wandb):
    # Choose logging and checkpoint saving directory
    logdir = utils.choose_log_dir(
        f"{utils.run_dir}/{hps['dataset_name']}/training/{hps['training_algorithm']}",
        debug=hps["debug"],
    )

    # Add a couple of keys to the hps object and save it as a yaml file
    hps["logdir"] = logdir

    hps["training_kwargs"]["run_name"] = "/".join(logdir.split("/")[-2:])
    hps["user"] = getpass.getuser()
    hps["tags"] += [
        hps["dataset"]["name"],
        "training",
        hps["training_algorithm"],
    ]
    with open(f"{logdir}/hps.yaml", "w") as f:
        yaml.dump(hps, f)

    # If not in debug mode, setup wandb logging
    if not hps["debug"] or log_wandb:
        wandb.init(
            project="dpo_rlhf_generalization",
            dir=logdir,
            name=hps["training_kwargs"]["run_name"],
            config=utils.wandb_configify(hps),
            tags=hps["tags"],
            save_code=True,
            settings=wandb.Settings(code_dir="."),
        )

    print(f"Hyperparameters:\n{hps}\n")
    return logdir

In [4]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

In [5]:
def custom_collate(batch):
    input_ids = [item['input_ids'] for item in batch]
    queries = [item['query'] for item in batch]

    max_length = max(len(ids) for ids in input_ids)
    input_ids = [[tokenizer.pad_token_id] * (max_length - len(ids)) + ids for ids in input_ids]

    input_ids = t.tensor(input_ids)
    return {'input_ids': input_ids, 'queries': queries}
    
def tokenize(sample):
    sample["input_ids"] = tokenizer.encode(
        sample["query"].replace("</s>",""),
    )

    # sample["input_ids"] = tokenizer(
    #     sample["query"],
    #     truncation=True,
    #     max_length=512,
    #     padding='max_length',
    #     return_tensors='pt',
    # )['input_ids']
    # sample["input_ids"] = sample['input_ids'].squeeze(0)
    return sample

def collator(data):
    return {key: [d[key] for d in data] for key in data[0]}


In [6]:
# RUN THIS BLOCK IF YOU CHANGE YAML FILE BUT DON'T WANT TO RERUN WHOLE NOTEBOOK

args = 'hyperparams/rlhf.yaml'
with open(
    args
) as f:
    hps = yaml.load(f, Loader=yaml.FullLoader)


In [7]:
# load model
tokenizer, model = utils.load_model(
    hps["model"],
    reward_model=False,
    eval=False,
    quantized=True,
    bnb_config=bnb_config,
)
# tokenizer.padding_side = 'left'
model.config.pad_token_id = tokenizer.eos_token_id

print(tokenizer)
# hps["generator_peft_config_kwargs"]

model = trl.AutoModelForCausalLMWithValueHead.from_pretrained(model)
  # load_in_4bit=True,
  # peft_config=hps["peft_config_class"](hps["generator_peft_config_kwargs"]))


# load reward model
reward_model = AutoModelForSequenceClassification.from_pretrained(hps["rm_path"])
reward_model = reward_model.to(t.device("cuda:0")).eval()
tokenizer_reward = AutoTokenizer.from_pretrained(hps["rm_path"])
reward_model.config.pad_token_id = tokenizer.eos_token_id


`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



LlamaTokenizerFast(name_or_path='./drive/hh-sft-instruct-7b/sft_model', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}




In [8]:
# for layer_idx, layer in enumerate(model.pretrained_model.model.layers):
#     if layer_idx < 30:  # Adjust the index based on zero-indexing
#         for param in layer.parameters():
#             print(layer, param.requires_grad)

In [9]:
hps["debug"] = False

# Load and process dataset. Make eval set smaller for speed reasons.
dataset = utils.load_dataset(tokenizer, **hps["dataset"], debug=hps["debug"])
test_size = min(len(dataset["test"]), 2_000)
dataset["test"] = dataset["test"].shuffle(seed=42).select(range(test_size))

dataset = dataset.rename_column("prompt", "query")
dataset = dataset.map(tokenize, batched=False)
dataset = dataset.remove_columns(["chosen", "rejected"])

print("Dataset size:", len(dataset['train']))

Dataset size: 144720


In [10]:
# To keep debug runs short

# if hps["debug"]:
#     hps["training_kwargs"]["max_steps"] = 5

config = PPOConfig(
    model_name="mistralai/Mistral-7B-Instruct-v0.2",
    # **hps["training_kwargs"]
    batch_size=hps["training_kwargs"]["batch_size"],
    gradient_accumulation_steps=hps["training_kwargs"]["gradient_accumulation_steps"],
    mini_batch_size=hps["training_kwargs"]["mini_batch_size"],
    learning_rate=float(hps["training_kwargs"]["learning_rate"]),
    log_with="wandb",
    optimize_device_cache = True,
    
)

# sent_kwargs = {
#     "return_all_scores": True,
#     "function_to_apply": "none",
#     "batch_size": 4,
# }

ppo_trainer = PPOTrainer(
    model=model,
    config=config,
    dataset=dataset['train'],
    tokenizer=tokenizer,  
    data_collator=collator,
)

# dl = ppo_trainer.prepare_dataloader(dataset['train'], data_collator=custom_collate)
# num_epochs = 2

generation_kwargs = {
    "min_length": 10,
    # "temperature": 0.7,
    "top_k": 0,
    "top_p": .9,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 100,
}

# ppo_trainer.train(dl, num_epochs = 1)



ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmgerov[0m. Use [1m`wandb login --relogin`[0m to force relogin


allocated_memory = t.cuda.memory_allocated()
print(f"memory allocated: {allocated_memory / (2**30)} / ~80 GBs")

In [11]:
# Setting logging
# logdir = setup_logging(hps, True)

In [None]:
TOK = "[/INST]"

epochs = 1
for epoch in tqdm(range(epochs), "epoch: "):
    for i, batch in tqdm(enumerate(ppo_trainer.dataloader)):
        allocated_memory = t.cuda.memory_allocated()
        print(f"memory allocated: {allocated_memory / (2**30)}")

        inputs = [t.tensor(sublist) for sublist in batch['input_ids']]

        # query_tensors = t.stack(batch['input_ids'],1)
        # print(query_tensors.shape)
        # query_tensors = [tensor.view(-1) for tensor in query_tensors]
        #### Get response from SFTModel
        response_tensors = ppo_trainer.generate(inputs, **generation_kwargs)

        batch["response"] = [
            tokenizer.decode(r.squeeze()) for r in response_tensors
        ]
        response_strings = []
        response_tensors_sliced = []
        
        for response in batch['response']:
            response_string = response.split(TOK)[-1]
            response_strings.append(response_string)
            response_tensors_sliced.append(t.tensor(tokenizer(response_string)['input_ids'][2:]))
        
        allocated_memory = t.cuda.memory_allocated()
        print(f"memory allocated: {allocated_memory / (2**30)}")        
        # print(batch['query'])
        # print(batch['response'])
        #### Compute reward score
        chosen_scores = list(reward_fn(reward_model, tokenizer_reward, batch["query"], response_strings, device).flatten())
        t.cuda.empty_cache()
        #### Run PPO step
        allocated_memory = t.cuda.memory_allocated()
        print(f"memory allocated: {allocated_memory / (2**30)}")
        # for (i, response) in enumerate(response_tensors):
        #     if len(response) == 1:
        #         chosen_scores[i] -= 5

        # print(chosen_scores)

        if i % 20 == 0:
            for (query, response, score) in zip(batch['query'],  response_strings, chosen_scores):
                print('QUERY: ' + query)
                print('RESPONSE: ' + response)
                print('SCORES: ' + str(score))
                print("\n\n")
            
        stats = ppo_trainer.step(inputs, response_tensors_sliced, chosen_scores)
        ppo_trainer.log_stats(stats, batch, chosen_scores)
        del stats, batch, chosen_scores
        # t.cuda.empty_cache()

        # wandb.log(stats)
        if i % 100 == 0:
            ppo_trainer.save_pretrained(f"rlhf_saved/test/step_{i}")
#### Save model


epoch:   0%|          | 0/1 [00:00<?, ?it/s]
0it [00:00, ?it/s][AYou're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


memory allocated: 11.221648693084717
memory allocated: 11.229649066925049
memory allocated: 11.229649543762207
QUERY: <s>[INST] Does getting a college education, lead to longer life expectancy overall? I've wondered. [/INST]According to several studies, people who go to college live longer lives, and on average they’re healthier and happier.</s>[INST] Tell me more about any of what you just stated, I'm curious! [/INST]People who went to college tend to have better social and career connections and also better physical health.  They're more likely to vote in elections and vote for policy positions that benefit them or other people in the long-term.</s>[INST] I'm curious about those assertions. Can you name a specific study or two? [/INST]Sure!  
I can tell you that according to research published by Brookings, 50% of college graduates in the 50's were still alive in the 90's, compared to 27% of non-graduates.

I can tell you that according to a Harvard study, the total economic contribu


1it [00:52, 52.60s/it][A

memory allocated: 11.238348007202148
memory allocated: 11.238330841064453
memory allocated: 11.238331317901611



2it [01:38, 48.56s/it][A

memory allocated: 11.238330841064453
memory allocated: 11.238333702087402
memory allocated: 11.23833417892456



3it [02:25, 47.82s/it][A

memory allocated: 11.238333702087402
memory allocated: 11.238335132598877
memory allocated: 11.238335609436035



4it [03:09, 46.39s/it][A

memory allocated: 11.238335132598877
memory allocated: 11.238327026367188
memory allocated: 11.238327503204346



5it [03:50, 44.48s/it][A

memory allocated: 11.238327026367188
memory allocated: 11.238341808319092
memory allocated: 11.23834228515625



6it [04:41, 46.82s/it][A

memory allocated: 11.238341808319092
memory allocated: 11.238336086273193
memory allocated: 11.238336563110352



7it [05:30, 47.46s/it][A

memory allocated: 11.238336086273193
memory allocated: 11.238319873809814
memory allocated: 11.238320350646973



8it [06:14, 46.36s/it][A

memory allocated: 11.238319873809814
memory allocated: 11.238340854644775
memory allocated: 11.238341331481934



9it [07:06, 48.09s/it][A

memory allocated: 11.238340854644775
memory allocated: 11.238332271575928
memory allocated: 11.238332748413086



10it [07:52, 47.40s/it][A

memory allocated: 11.238332271575928
memory allocated: 11.238362789154053
memory allocated: 11.238363265991211



11it [08:49, 50.30s/it][A

memory allocated: 11.238362789154053
memory allocated: 11.238334655761719
memory allocated: 11.238335132598877



12it [09:36, 49.41s/it][A

memory allocated: 11.238334655761719
memory allocated: 11.238343238830566
memory allocated: 11.238343715667725



13it [10:26, 49.41s/it][A

memory allocated: 11.238343238830566
memory allocated: 11.238328456878662
memory allocated: 11.23832893371582



14it [11:09, 47.62s/it][A

memory allocated: 11.238328456878662
memory allocated: 11.238372802734375
memory allocated: 11.238373279571533



15it [12:06, 50.44s/it][A

memory allocated: 11.238372802734375
memory allocated: 11.23838996887207
memory allocated: 11.238390445709229



16it [13:14, 55.73s/it][A

memory allocated: 11.23838996887207
memory allocated: 11.23832082748413
memory allocated: 11.238321304321289



17it [14:01, 53.12s/it][A

memory allocated: 11.23832082748413
memory allocated: 11.23832082748413
memory allocated: 11.238321304321289



18it [14:41, 49.24s/it][A

memory allocated: 11.23832082748413
memory allocated: 11.238329410552979
memory allocated: 11.238329887390137



19it [15:32, 49.67s/it][A

memory allocated: 11.238329410552979
memory allocated: 11.238338947296143
memory allocated: 11.2383394241333



20it [16:13, 47.14s/it][A

memory allocated: 11.238338947296143
memory allocated: 11.238336563110352
memory allocated: 11.23833703994751



21it [17:00, 46.94s/it][A

memory allocated: 11.238336563110352
memory allocated: 11.238348484039307
memory allocated: 11.238348960876465



22it [17:52, 48.61s/it][A

memory allocated: 11.238348484039307
memory allocated: 11.238330364227295
memory allocated: 11.238330841064453



23it [18:41, 48.63s/it][A

memory allocated: 11.238330364227295
memory allocated: 11.238321781158447
memory allocated: 11.238322257995605



24it [19:27, 47.97s/it][A

memory allocated: 11.238321781158447
memory allocated: 11.238341331481934
memory allocated: 11.238341808319092



25it [20:12, 46.91s/it][A

memory allocated: 11.238341331481934
memory allocated: 11.238317012786865
memory allocated: 11.238317489624023



26it [20:55, 45.75s/it][A

memory allocated: 11.238317012786865
memory allocated: 11.238329887390137
memory allocated: 11.238330364227295



27it [21:44, 46.71s/it][A

memory allocated: 11.238329887390137
memory allocated: 11.238342761993408
memory allocated: 11.238343238830566



28it [22:32, 47.19s/it][A

memory allocated: 11.238342761993408
memory allocated: 11.2383451461792
memory allocated: 11.238345623016357



29it [23:26, 49.24s/it][A

memory allocated: 11.2383451461792
memory allocated: 11.238336086273193
memory allocated: 11.238336563110352



30it [24:15, 48.99s/it][A

memory allocated: 11.238336086273193
memory allocated: 11.238338947296143
memory allocated: 11.2383394241333



31it [25:05, 49.52s/it][A

memory allocated: 11.238338947296143
memory allocated: 11.238361358642578
memory allocated: 11.238361835479736



32it [26:01, 51.42s/it][A

memory allocated: 11.238361358642578
memory allocated: 11.23833417892456
memory allocated: 11.238334655761719



33it [26:48, 50.05s/it][A

memory allocated: 11.23833417892456
memory allocated: 11.238315105438232
memory allocated: 11.23831558227539



34it [27:31, 47.79s/it][A

memory allocated: 11.238315105438232
memory allocated: 11.238331317901611
memory allocated: 11.23833179473877



35it [28:20, 48.20s/it][A

memory allocated: 11.238331317901611
memory allocated: 11.238327503204346
memory allocated: 11.238327980041504



36it [29:07, 48.08s/it][A

memory allocated: 11.238327503204346
memory allocated: 11.238314628601074
memory allocated: 11.238315105438232



37it [29:48, 45.75s/it][A

memory allocated: 11.238314628601074
memory allocated: 11.23832082748413
memory allocated: 11.238321304321289



38it [30:34, 45.82s/it][A

memory allocated: 11.23832082748413
memory allocated: 11.238348007202148
memory allocated: 11.238348484039307



39it [31:25, 47.58s/it][A

memory allocated: 11.238348007202148
memory allocated: 11.238339900970459
memory allocated: 11.238340377807617



40it [32:17, 48.78s/it][A

memory allocated: 11.238339900970459
memory allocated: 11.238349914550781
memory allocated: 11.23835039138794



41it [33:09, 49.81s/it][A

memory allocated: 11.238349914550781
memory allocated: 11.238342761993408
memory allocated: 11.238343238830566



42it [34:02, 50.55s/it][A

memory allocated: 11.238342761993408
memory allocated: 11.238330841064453
memory allocated: 11.238331317901611



43it [34:46, 48.80s/it][A

memory allocated: 11.238330841064453
memory allocated: 11.238336086273193
memory allocated: 11.238336563110352



44it [35:35, 48.86s/it][A

memory allocated: 11.238336086273193
memory allocated: 11.238346576690674
memory allocated: 11.238347053527832



45it [36:27, 49.62s/it][A

memory allocated: 11.238346576690674
memory allocated: 11.238334655761719
memory allocated: 11.238335132598877



46it [37:14, 48.88s/it][A

memory allocated: 11.238334655761719
memory allocated: 11.238329887390137
memory allocated: 11.238330364227295



47it [37:59, 47.69s/it][A

memory allocated: 11.238329887390137
memory allocated: 11.238332271575928
memory allocated: 11.238332748413086



48it [38:46, 47.53s/it][A

memory allocated: 11.238332271575928
memory allocated: 11.238357543945312
memory allocated: 11.23835802078247



49it [39:38, 48.90s/it][A

memory allocated: 11.238357543945312
memory allocated: 11.238337516784668
memory allocated: 11.238337993621826



50it [40:26, 48.76s/it][A

memory allocated: 11.238337516784668
memory allocated: 11.238318920135498
memory allocated: 11.238319396972656



51it [41:09, 46.85s/it][A

memory allocated: 11.238318920135498
memory allocated: 11.238368034362793
memory allocated: 11.238368511199951



52it [42:05, 49.78s/it][A

memory allocated: 11.238368034362793
memory allocated: 11.238332271575928
memory allocated: 11.238332748413086



53it [42:55, 49.62s/it][A

memory allocated: 11.238332271575928
memory allocated: 11.238323211669922
memory allocated: 11.23832368850708



54it [43:39, 48.04s/it][A

memory allocated: 11.238323211669922
memory allocated: 11.23834228515625
memory allocated: 11.238342761993408



55it [44:29, 48.72s/it][A

memory allocated: 11.23834228515625
memory allocated: 11.238349914550781
memory allocated: 11.23835039138794



56it [45:21, 49.47s/it][A

memory allocated: 11.238349914550781
memory allocated: 11.23831844329834
memory allocated: 11.238318920135498



57it [46:04, 47.71s/it][A

memory allocated: 11.23831844329834
memory allocated: 11.238332748413086
memory allocated: 11.238333225250244



58it [46:47, 46.37s/it][A

memory allocated: 11.238332748413086
memory allocated: 11.238317489624023
memory allocated: 11.238317966461182



59it [47:31, 45.44s/it][A

memory allocated: 11.238317489624023
memory allocated: 11.238362789154053
memory allocated: 11.238363265991211



60it [48:26, 48.42s/it][A

memory allocated: 11.238362789154053
memory allocated: 11.238328456878662
memory allocated: 11.23832893371582



61it [49:15, 48.56s/it][A

memory allocated: 11.238328456878662
memory allocated: 11.238337516784668
memory allocated: 11.238337993621826



62it [50:04, 48.76s/it][A

memory allocated: 11.238337516784668
memory allocated: 11.238329410552979
memory allocated: 11.238329887390137



63it [50:47, 47.13s/it][A

memory allocated: 11.238329410552979
memory allocated: 11.238339900970459
memory allocated: 11.238340377807617



64it [51:38, 48.10s/it][A

memory allocated: 11.238339900970459
memory allocated: 11.238328456878662
memory allocated: 11.23832893371582



65it [52:25, 47.91s/it][A

memory allocated: 11.238328456878662
memory allocated: 11.238324642181396
memory allocated: 11.238325119018555



66it [53:10, 47.09s/it][A

memory allocated: 11.238324642181396
memory allocated: 11.238353252410889
memory allocated: 11.238353729248047



67it [53:58, 47.10s/it][A

memory allocated: 11.238353252410889
memory allocated: 11.238341331481934
memory allocated: 11.238341808319092



68it [54:45, 47.18s/it][A

memory allocated: 11.238341331481934
memory allocated: 11.238311767578125
memory allocated: 11.238312244415283



69it [55:27, 45.68s/it][A

memory allocated: 11.238311767578125
memory allocated: 11.23835802078247
memory allocated: 11.238358497619629



70it [56:19, 47.54s/it][A

memory allocated: 11.23835802078247
memory allocated: 11.238335132598877
memory allocated: 11.238335609436035



71it [57:04, 46.69s/it][A

memory allocated: 11.238335132598877
memory allocated: 11.238323211669922
memory allocated: 11.23832368850708



72it [57:54, 47.65s/it][A

memory allocated: 11.238323211669922
memory allocated: 11.238344192504883
memory allocated: 11.238344669342041



73it [58:46, 49.20s/it][A

memory allocated: 11.238344192504883
memory allocated: 11.238322257995605
memory allocated: 11.238322734832764



74it [59:30, 47.46s/it][A

memory allocated: 11.238322257995605
memory allocated: 11.238317489624023
memory allocated: 11.238317966461182



75it [1:00:07, 44.24s/it][A

memory allocated: 11.238317489624023
memory allocated: 11.23831558227539
memory allocated: 11.238316059112549



76it [1:00:47, 42.97s/it][A

memory allocated: 11.23831558227539
memory allocated: 11.238341808319092
memory allocated: 11.23834228515625



77it [1:01:37, 45.21s/it][A

memory allocated: 11.238341808319092
memory allocated: 11.23833417892456
memory allocated: 11.238334655761719



78it [1:02:28, 47.05s/it][A

memory allocated: 11.23833417892456
memory allocated: 11.238344669342041
memory allocated: 11.2383451461792



79it [1:03:17, 47.42s/it][A

memory allocated: 11.238344669342041
memory allocated: 11.238337993621826
memory allocated: 11.238338470458984



80it [1:04:07, 48.41s/it][A

memory allocated: 11.238337993621826
memory allocated: 11.238356590270996
memory allocated: 11.238357067108154



81it [1:05:06, 51.53s/it][A

memory allocated: 11.238356590270996
memory allocated: 11.238341331481934
memory allocated: 11.238341808319092



82it [1:06:03, 53.09s/it][A

memory allocated: 11.238341331481934
memory allocated: 11.238351821899414
memory allocated: 11.238352298736572



83it [1:06:55, 52.77s/it][A

memory allocated: 11.238351821899414
memory allocated: 11.238322257995605
memory allocated: 11.238322734832764



84it [1:07:39, 50.23s/it][A

memory allocated: 11.238322257995605
memory allocated: 11.238321781158447
memory allocated: 11.238322257995605



85it [1:08:29, 50.01s/it][A

memory allocated: 11.238321781158447
memory allocated: 11.238327026367188
memory allocated: 11.238327503204346



86it [1:09:14, 48.65s/it][A

memory allocated: 11.238327026367188


In [20]:
dir(ppo_trainer)

['__annotations__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_decode_arg',
 '_early_stop',
 '_encode_arg',
 '_filter_kwargs',
 '_from_pretrained',
 '_generate_batched',
 '_hub_mixin_coders',
 '_hub_mixin_config',
 '_hub_mixin_info',
 '_hub_mixin_init_parameters',
 '_hub_mixin_inject_config',
 '_hub_mixin_jsonable_custom_types',
 '_hub_mixin_jsonable_default_values',
 '_is_jsonable',
 '_kl_penalty',
 '_load_as_pickle',
 '_load_as_safetensor',
 '_prepare_deepspeed',
 '_remove_unused_columns',
 '_save_pretrained',
 '_set_signature_columns_if_needed',
 '_show_tokens',
 '_signature_columns',
 '_step_safety_checker',
 '_tag_names',
 'accelerator',
 'batched_forward_p

In [24]:
inputs = [t.tensor(sublist) for sublist in batch['input_ids']]

# query_tensors = t.stack(batch['input_ids'],1)
# print(query_tensors.shape)
# query_tensors = [tensor.view(-1) for tensor in query_tensors]
#### Get response from SFTModel
response_tensors = ppo_trainer.generate(inputs, **generation_kwargs)

batch["response"] = [
    tokenizer.decode(r.squeeze()) for r in response_tensors
]

In [67]:
response_tensors[1]

tensor([    1,     1,   733, 16289, 28793,  6526,   460,  6068,   567,  8144,
          524,  2973,   291, 28804,   733, 28748, 16289, 28793,  6824,  6068,
          349,  9589,   272,  3057, 28725,   304,  8144,   349,   272,  3140,
        28723, 28705,   315, 28809, 28719,  3468,  1864,   369,   590, 28809,
          267,   264,  1424, 27041,   745,  2751,   302,   272,  2556,  2005,
        28725,  1854,   378,   403,   264,  2055,   680,  3298,   354,   272,
         3057,   298,  3079,   438,  1611,   395,   272,  2436,   304,   272,
         3140,   298,  3084,   354,   272,  2005, 28725,   852,   868, 20011,
        16289, 28793,  1306,   654,   264, 12254, 28726,  7915,  4308,   733,
        28748, 16289, 28793,  4957,   368,  2072,   590,   654,   477,   272,
         2939, 28804, 28705,  4673,   297,   272,  5046,  1347,  6068,   304,
         8144,   524,  2973,   291,   654,   477,   264,  1628, 13359,  3736,
         1987,  5040,  4401, 28725,   690,   315,  1073,   403, 

In [61]:
original = response_strings[0]
encoded = tokenizer(original)['input_ids'][2:]
post = tokenizer.decode(encoded)

In [62]:
print(original, post)

 It's important to remember that making insensitive or hurtful comments about someone's appearance is not respectful or kind. Here are some positive and uplifting things you can say to someone instead:

1. "You have a beautiful face, and your smile lights up the room."
2. "Your face is so unique and interesting. I love the way your features come together."
3. "You have a warm and friendly presence that makes people feel welcome."
 It's important to remember that making insensitive or hurtful comments about someone's appearance is not respectful or kind. Here are some positive and uplifting things you can say to someone instead:

1. "You have a beautiful face, and your smile lights up the room."
2. "Your face is so unique and interesting. I love the way your features come together."
3. "You have a warm and friendly presence that makes people feel welcome."



In [40]:
response_strings = []

TOK = "[/INST]"

print(batch['response'][1])

print("--------------------------")




<s><s> [INST] What are some interesting conspiracy theories? [/INST]Sure! Conspiracy theories are theories that explain how a conspiracy might be at work. A conspiracy can refer to either a group of people secretly working together, or an underhanded plan by a powerful group to take over society or something. Here’s a list of conspiracy theories, or possible conspiracy theories. 

Hillary Clinton: She supposedly helped get my boyfriend fired from his job!

Assassination of JFK: He was a good man, but they killed him.

Aliens: They took my best friend!

Global warming: It’s not real!

Government weather control: They’re planning to enslave humanity!

Water fluoridation: I was fine before they put fluoride in the water. 

Amazon: They know I’ve ordered the number one bestseller on the list, but they’ve kept it from me! 

Holographic universe: We’re living in the Matrix! 

Orwellian surveillance: They’re spying on me!

Presidential election: The polls are fixed!

Sandy Hook Elementary Sch

# ignore below? 

In [None]:
batch = next(iter(ppo_trainer.dataloader))


In [26]:
len(dataset['train']['input_ids'])

1000

In [17]:
len(batch['queries'])

NameError: name 'batch' is not defined

In [25]:
query_tensors = batch["input_ids"]
# print(query_tensors.shape)
query_tensors = [tensor.view(-1) for tensor in query_tensors]

In [26]:
#### Get response from SFTModel
response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)

batch["response"] = [
    tokenizer.decode(r.squeeze()) for r in response_tensors
]

In [27]:
#### Compute reward score
# texts = [q + r for q, r in zip(batch["queries"], batch["response"])]
chosen_scores = list(reward_fn(reward_model, tokenizer, batch["queries"], batch["response"], device).flatten())
# rewards = [t.tensor(output[1]["score"]) for output in pipe_outputs]
print(chosen_scores)

t.cuda.empty_cache()

[tensor(0.4236, device='cuda:0')]


In [10]:
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Sat May 11 09:36:52 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:CA:00.0 Off |                    0 |
| N/A   37C    P0              68W / 400W |  81013MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [29]:
#### Run PPO step
stats = ppo_trainer.step(query_tensors, response_tensors, chosen_scores)
ppo_trainer.log_stats(stats, batch, chosen_scores)

In [None]:
chosen_scores = list(reward_fn(reward_model, tokenizer, batch["queries"], batch["response"], device).flatten())

In [None]:
stats = ppo_trainer.step(query_tensors, response_tensors, chosen_scores)

In [None]:
    # I think PPO trainer fine tunes already, so we don't need this
#     peft_config = LoraConfig(
    
#     task_type=TaskType.CAUSAL_LM, inference_mode=False, r=32, lora_alpha=16, lora_dropout=0.1,
# ) # create LoRA config for the finetuning

#     model = get_peft_model(model, peft_config) # create a model ready for LoRA finetuning

#     tokenizer.pad_token = tokenizer.eos_token # need this because tokenizer doesn't have default padding

#     # fine tune!
#     training_args = TrainingArguments(
#         output_dir="./results",
#         num_train_epochs=3,
#         per_device_train_batch_size=1,
#         per_device_eval_batch_size=2,
#         warmup_steps=500,
#         weight_decay=0.01,
#         logging_dir=logdir,
#         logging_steps=10,
#         learning_rate = 1e-3,
#     )

#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=dataset,
#     )
#     trainer.train()