In [1]:
import warnings
warnings.filterwarnings('ignore')
import os
import sys
import time

In [2]:
# Request offline mode. This is set before `transformers` is imported further down;
# presumably the library reads the flag at import time — keep this cell first (TODO confirm).
os.environ['TRANSFORMERS_OFFLINE']="1"

In [3]:
# Point the Transformers cache at a shared Hugging Face hub cache directory
# (presumably pre-populated, since offline mode is requested as well).
# NOTE(review): recent transformers releases appear to deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME / HF_HUB_CACHE — confirm against the installed version.
os.environ['TRANSFORMERS_CACHE'] = '/scratch/shareddata/dldata/huggingface-hub-cache/hub'

In [4]:
import torch
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from transformers import AutoTokenizer, AutoModelForCausalLM

## Generate responses

In [None]:
# Identifier of the checkpoint passed to both the tokenizer and the model loaders below.
model_name = "meta-llama/Llama-2-7b-chat-hf"
# Alternative checkpoint, left disabled:
# model_name = "gpt2"

In [None]:
# Load the tokenizer that matches `model_name`.
# Pad on the LEFT: the prompts are later tokenized in batches with padding=True and fed to
# `generate`. A decoder-only model continues from the last position of each row, so pad
# tokens must come before the prompt rather than between the prompt and the continuation.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# Reuse EOS as the padding token so batched tokenization works even if the tokenizer
# defines no pad token of its own.
tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
# Causal LM wrapped with a TRL value head; device placement is delegated to device_map="auto".
# Loading options left disabled: load_in_8bit=True, torch_dtype=torch.float16
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name,
    device_map="auto",
)
# Reference model, left disabled:
# model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(model_name, torch_dtype=torch.float16)

In [None]:
# Device reported by the wrapped base model; tokenized inputs are moved to it before generation.
# NOTE(review): with device_map="auto" the weights may be spread over several devices —
# confirm this is the device that should receive the input tensors.
device = model.pretrained_model.device

In [None]:
# Read the prompts, one per line, from the text file.
# - An explicit encoding keeps the result independent of the machine's locale.
# - Blank lines are skipped so that no empty prompt is sent to the model.
with open('prompts.txt', 'r', encoding='utf-8') as file:
    prompts_list = [line.strip() for line in file if line.strip()]

# Quick sanity check: first few prompts and the total count.
print( prompts_list[:10])
print(len(prompts_list))

In [None]:
# Number of prompts sent through the model per generate() call.
batch_size = 2

# Consecutive slices of `prompts_list`; the final slice may be shorter than `batch_size`.
prompt_batches = [
    prompts_list[start:start + batch_size]
    for start in range(0, len(prompts_list), batch_size)
]

In [None]:
# Run every prompt batch through the model and gather the decoded texts in order.
all_responds = []
for batch in prompt_batches:
    # Pad to the longest prompt in the batch, cap prompts at 30 tokens, move to the model device.
    inputs = tokenizer(
        batch,
        padding=True,
        truncation=True,
        max_length=30,
        return_tensors="pt",
    ).to(device)

    # Generation settings are whatever the model's defaults are; nothing is overridden here.
    outputs = model.generate(**inputs)

    # Token ids -> strings, with special tokens removed.
    responds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    all_responds += responds


In [None]:
# Number of generated texts collected (presumably expected to equal len(prompts_list) — verify).
len(all_responds)

In [5]:
from datasets import Dataset

# On-disk location of the cached prompt/response pairs.
dataset_path = "./completion_llama2_7b_chat"

if not os.path.isdir(dataset_path):
    # No cached completions yet: persist the ones produced by the generation cells above
    # so that later runs can skip the expensive generation step.
    data_dict = {
        "prompt": prompts_list,  # List of all prompts
        "response": all_responds  # Corresponding generated texts
    }
    dataset = Dataset.from_dict(data_dict)
    dataset.save_to_disk(dataset_path)

# Always read back from disk so downstream cells see the same object on fresh and cached runs.
loaded_dataset = Dataset.load_from_disk(dataset_path)

## Get rewards

In [7]:
from transformers import pipeline

# Zero-shot classifier whose scores are turned into rewards in the cells below.
# Use the first CUDA device when one is available; otherwise fall back to CPU (-1)
# instead of failing on a machine without a GPU.
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=0 if torch.cuda.is_available() else -1,
)

In [8]:
# Classify every stored response against the two sentiment labels.
candidate_labels = ['positive', 'negative']
sequences_to_classify = loaded_dataset['response']
results = classifier(
    sequences_to_classify,
    candidate_labels=candidate_labels,
)


In [13]:
# Peek at the first two classification results (each holds 'sequence', 'labels' and 'scores').
results[:2]

[{'sequence': 'As the sun dipped below the horizon, she whispered Sheila\'s name into the darkness.\n\nThe wind rustled through the trees, carrying the sound of her voice to the ears of the sleeping girl. She stirred, her eyelids flickering open, and she saw her mother standing over her, her face bathed in the soft, golden light of the setting sun.\n\n"Mother?" Sheila said, her voice barely above a whisper.\n\n"Shh, my child," her mother replied, her voice low and soothing. "I\'m here, and I\'ll always be here for you."\n\nSheila smiled, feeling a sense of peace wash over her. She knew that her mother would always be there to protect her, to guide her, and to love her. And as the stars began to twinkle in the sky, she drifted off to sleep, surrounded by the warmth and love of her mother\'s embrace.',
  'labels': ['positive', 'negative'],
  'scores': [0.8609581589698792, 0.13904178142547607]},
 {'sequence': 'In the heart of the bustling city, the old clock tower stood tall and proud, a 

In [14]:
# Reward = probability assigned to the 'positive' label.
# The zero-shot pipeline returns 'labels' and 'scores' sorted by score (highest first), so
# scores[0] is the winning label's score — which is the 'negative' score whenever a text is
# classified as negative. Look up the position of 'positive' explicitly instead.
rewards = [
    result['scores'][result['labels'].index('positive')]
    for result in results
]
rewards

[0.8609581589698792,
 0.9228175282478333,
 0.5317795276641846,
 0.8464369773864746,
 0.8121612668037415,
 0.8203597664833069,
 0.9301291704177856,
 0.8936611413955688,
 0.6091729998588562,
 0.7915114164352417,
 0.8447042107582092,
 0.5947414636611938,
 0.9344220161437988,
 0.7813530564308167,
 0.6773236989974976,
 0.8200059533119202,
 0.5995572209358215,
 0.757203221321106,
 0.9054793119430542,
 0.6570038795471191,
 0.6353925466537476,
 0.8420090675354004,
 0.9306687712669373,
 0.581459105014801,
 0.5401076674461365,
 0.8723202347755432,
 0.7645230889320374,
 0.8475698828697205,
 0.9527148604393005,
 0.611027181148529,
 0.5725876092910767,
 0.9297681450843811,
 0.7758680582046509,
 0.8660714030265808,
 0.7207587361335754,
 0.6905563473701477,
 0.7340672016143799,
 0.7655211687088013,
 0.8030493259429932,
 0.8590236306190491,
 0.6460028290748596,
 0.5988184809684753,
 0.9237436056137085,
 0.696963906288147,
 0.853239119052887,
 0.9171207547187805,
 0.7990365028381348,
 0.776908934116363