In [1]:
# !pip install transformers
# !pip install accelerate
# !pip install twine
# !pip install datasets
# !pip install tyro

In [2]:
!nvidia-smi

Wed Mar 20 11:37:05 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-PCIE-32GB            On | 00000000:3B:00.0 Off |                  Off |
| N/A   31C    P0               24W / 250W|      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                         

In [3]:
# !pip install wandb
# !pip install trl
# !pip install pandas
# !pip install datasets
# !pip install nltk -U

In [4]:
import torch
from tqdm import tqdm
import pandas as pd
import wandb
import os

tqdm.pandas()

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

In [5]:
config = PPOConfig(
    model_name    = "openai-community/gpt2",
    learning_rate = 1.41e-5,
    ## log_with      = "wandb",
)

sent_kwargs = {
         "return_all_scores": True, 
         "function_to_apply": "none", 
         "batch_size": 16
}
print(config)

PPOConfig(exp_name='ipykernel_launcher', seed=0, log_with=None, task_name=None, model_name='openai-community/gpt2', query_dataset='imdb', reward_model='sentiment-analysis:lvwerra/distilbert-imdb', remove_unused_columns=True, tracker_kwargs={}, accelerator_kwargs={}, project_kwargs={}, tracker_project_name='trl', push_to_hub_if_best_kwargs={}, steps=20000, learning_rate=1.41e-05, adap_kl_ctrl=True, init_kl_coef=0.2, kl_penalty='kl', target=6, horizon=10000, gamma=1, lam=0.95, cliprange=0.2, cliprange_value=0.2, vf_coef=0.1, batch_size=128, forward_batch_size=None, mini_batch_size=128, gradient_accumulation_steps=1, world_size=None, ppo_epochs=4, max_grad_norm=None, optimize_cuda_cache=None, optimize_device_cache=False, early_stopping=False, target_kl=1, compare_steps=1, ratio_threshold=10.0, use_score_scaling=False, use_score_norm=False, score_clip=None, whiten_rewards=False, is_encoder_decoder=None, is_peft_model=None, backward_batch_size=128, global_backward_batch_size=None, global_ba

In [6]:
## wandb.init()

wandb.init(mode="disabled") 
os.environ['WANDB_DISABLED'] = 'true'

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


In [7]:
dataset_name="ag_news"

In [8]:
ds = load_dataset(dataset_name, split = "train[:100000]")

In [9]:
ds

Dataset({
    features: ['text', 'label'],
    num_rows: 100000
})

In [10]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

In [11]:
tokenizer           = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token
# print(tokenizer)

In [12]:
def tokenize( sample ):
    sample["input_ids"] = tokenizer.encode( sample["text"]    )[: 100]
    sample["query"]     = tokenizer.decode( sample["input_ids"] )
    return sample

#print(tokenize)
ds = ds.map(tokenize, batched=False)


Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

In [13]:
print(ds)

Dataset({
    features: ['text', 'label', 'input_ids', 'query'],
    num_rows: 100000
})


In [14]:
def build_dataset(
         config, 
         dataset_name="ag_news", 
         input_min_text_length=2, 
         input_max_text_length=8
):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should
    customize this function to train the model on its own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """
    tokenizer           = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    
    # load with datasets
    
    ds = load_dataset(dataset_name, split="train[:100000]")
    
#     ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["text"]) > 100, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode( sample["text"]    )[: input_size()]
        sample["query"]     = tokenizer.decode( sample["input_ids"] )
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds

In [15]:
dataset = build_dataset(config)

Filter:   0%|          | 0/100000 [00:00<?, ? examples/s]

Map:   0%|          | 0/99939 [00:00<?, ? examples/s]

In [16]:
dataset

Dataset({
    features: ['text', 'label', 'input_ids', 'query'],
    num_rows: 99939
})

In [17]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

In [18]:
model     = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token

In [19]:
ppo_trainer = PPOTrainer(
                 config, 
                 model, 
                 ref_model, 
                 tokenizer, 
                 dataset=dataset, 
                 data_collator=collator
)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [20]:
device = ppo_trainer.accelerator.device
device

device(type='cuda')

In [21]:
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug

device

0

In [22]:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="wesleyacheng/news-topic-classification-with-bert")

In [23]:
generation_kwargs = {
    "min_length":     -1,
    "top_k":         0.0,
    "top_p":         1.0,
    "do_sample":    True,
    "pad_token_id": tokenizer.eos_token_id,
}

In [24]:
output_min_length     = 4
output_max_length     = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

In [25]:
ppo_trainer.config.steps

20000

In [26]:
# import os
# from transformers import GPT2LMHeadModel

# def save_checkpoint(model, filepath):
#     model.save_pretrained(filepath)
#     print("Model checkpoint saved successfully.")

# # Load Model Checkpoint
# def load_checkpoint(filepath):
#     if os.path.isdir(filepath):
#         model = GPT2LMHeadModel.from_pretrained(filepath)
#         print("Model checkpoint loaded successfully.")
#     else:
#         raise FileNotFoundError(f"Checkpoint directory not found at {filepath}")
#     return model



In [27]:
# import torch

# def gpu_memory_almost_full(threshold=0.9):
#     """
#     Check if GPU memory is almost full.
    
#     Args:
#         threshold (float): Threshold percentage for GPU memory usage.
#             If the current GPU memory usage exceeds this threshold, 
#             consider the GPU memory almost full. Default is 0.9 (90%).
    
#     Returns:
#         bool: True if GPU memory is almost full, False otherwise.
#     """
#     allocated_bytes = torch.cuda.memory_allocated()
#     total_bytes = torch.cuda.get_device_properties(0).total_memory  # assuming GPU 0
#     utilization = allocated_bytes / total_bytes
#     return utilization >= threshold

In [28]:
# checkpoint_dir = '/home/vemuri8/gpt2_checkpoint'
# if os.path.exists(checkpoint_dir):
#     model = load_checkpoint(checkpoint_dir)
# else:
#     # Replace "model_name" with the actual model name or identifier
#     model = GPT2LMHeadModel.from_pretrained("gpt2_checkpoint")

In [None]:
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    print(epoch)

    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len                             = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response                            = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append( response.squeeze()[-gen_len:] )
    batch["response"] = [ tokenizer.decode(r.squeeze()) for r in response_tensors ]

    #### Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = pipe(texts, **sent_kwargs)
    rewards = [ torch.tensor(output[1]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(
                     query_tensors, 
                     response_tensors, 
                     rewards
    )
#     ppo_trainer.log_stats(stats, batch, rewards)
#     if gpu_memory_almost_full():
#     save_checkpoint(model, 'gpt2_checkpoint')
#         break



0it [00:00, ?it/s]

0


1it [00:19, 19.67s/it]

1


2it [00:40, 20.23s/it]

2


3it [00:59, 19.79s/it]

3


4it [01:20, 20.37s/it]

4


5it [01:40, 20.18s/it]

5


6it [02:00, 20.16s/it]

6


7it [02:20, 19.91s/it]

7


8it [02:38, 19.32s/it]

8


9it [02:58, 19.62s/it]

9


10it [03:17, 19.53s/it]

10


11it [03:37, 19.53s/it]

11


12it [03:56, 19.34s/it]

12


13it [04:15, 19.22s/it]

13


14it [04:33, 19.01s/it]

14


15it [04:54, 19.57s/it]

15


16it [05:13, 19.33s/it]

16


17it [05:33, 19.60s/it]

17


18it [05:53, 19.66s/it]

18


19it [06:14, 20.11s/it]

19


20it [06:33, 19.86s/it]

20


21it [06:52, 19.52s/it]

21


22it [07:12, 19.56s/it]

22


23it [07:34, 20.23s/it]

23


24it [07:53, 20.14s/it]

24


25it [08:13, 19.93s/it]

25


26it [08:37, 21.15s/it]

26


27it [08:59, 21.52s/it]

27


28it [09:19, 21.10s/it]

28


29it [09:38, 20.40s/it]

29


30it [09:57, 19.95s/it]

30


31it [10:16, 19.57s/it]

31


32it [10:35, 19.57s/it]

32


33it [10:55, 19.51s/it]

33


34it [11:14, 19.58s/it]

34


35it [11:34, 19.45s/it]

35


36it [11:52, 19.21s/it]

36


37it [12:11, 19.19s/it]

37


38it [12:31, 19.21s/it]

38


39it [12:54, 20.30s/it]

39


40it [13:16, 20.95s/it]

40


41it [13:36, 20.62s/it]

41


42it [13:55, 20.16s/it]

42


43it [14:14, 19.95s/it]

43


44it [14:33, 19.58s/it]

44


45it [14:52, 19.43s/it]

45


46it [15:13, 19.91s/it]

46


47it [15:33, 19.95s/it]

47


48it [15:53, 19.79s/it]

48


49it [16:11, 19.25s/it]

49


50it [16:29, 18.96s/it]

50


51it [16:48, 18.86s/it]

51


52it [17:07, 18.98s/it]

52


53it [17:26, 19.10s/it]

53


54it [17:46, 19.30s/it]

54


55it [18:05, 19.31s/it]

55


56it [18:24, 19.22s/it]

56


57it [18:43, 19.10s/it]

57


58it [19:03, 19.30s/it]

58


59it [19:22, 19.11s/it]

59


60it [19:41, 19.18s/it]

60


61it [20:00, 19.31s/it]

61


62it [20:20, 19.30s/it]

62


63it [20:38, 19.08s/it]

63


64it [20:58, 19.33s/it]

64


65it [21:18, 19.31s/it]

65


66it [21:38, 19.79s/it]

66


67it [21:58, 19.76s/it]

67


68it [22:17, 19.63s/it]

68


69it [22:38, 19.79s/it]

69


70it [22:57, 19.57s/it]

70


71it [23:16, 19.52s/it]

71


72it [23:36, 19.50s/it]

72


73it [23:55, 19.60s/it]

73


74it [24:15, 19.50s/it]

74


75it [24:36, 20.17s/it]

75


76it [24:56, 19.92s/it]

76


77it [25:16, 19.94s/it]

77


78it [25:35, 19.83s/it]

78


79it [25:54, 19.57s/it]

79


80it [26:13, 19.33s/it]

80


81it [26:33, 19.54s/it]

81


82it [26:52, 19.27s/it]

82


83it [27:12, 19.59s/it]

83


84it [27:31, 19.34s/it]

84


85it [27:50, 19.41s/it]

85


86it [28:10, 19.42s/it]

86


87it [28:29, 19.39s/it]

87


88it [28:48, 19.27s/it]

88


89it [29:07, 19.24s/it]

89


90it [29:27, 19.37s/it]

90


91it [29:47, 19.53s/it]

91


92it [30:06, 19.46s/it]

92


93it [30:25, 19.35s/it]

93


94it [30:45, 19.45s/it]

94


95it [31:04, 19.40s/it]

95


96it [31:24, 19.61s/it]

96


97it [31:44, 19.61s/it]

97


98it [32:04, 19.73s/it]

98


99it [32:23, 19.62s/it]

99


100it [32:43, 19.65s/it]

100


101it [33:01, 19.31s/it]

101


102it [33:21, 19.28s/it]

102


103it [33:39, 19.08s/it]

103


104it [33:58, 18.99s/it]

104


105it [34:17, 19.02s/it]

105


106it [34:36, 19.02s/it]

106


107it [34:56, 19.23s/it]

107


108it [35:15, 19.24s/it]

108


109it [35:35, 19.45s/it]

109


110it [35:55, 19.69s/it]

110


111it [36:16, 19.95s/it]

111


112it [36:36, 19.87s/it]

112


113it [36:55, 19.79s/it]

113


114it [37:15, 19.88s/it]

114


115it [37:35, 19.87s/it]

115


116it [37:54, 19.48s/it]

116


117it [38:13, 19.38s/it]

117


118it [38:32, 19.29s/it]

118


119it [38:51, 19.31s/it]

119


120it [39:11, 19.56s/it]

120


121it [39:31, 19.43s/it]

121


122it [39:49, 19.07s/it]

122


123it [40:08, 19.16s/it]

123


124it [40:28, 19.22s/it]

124


125it [40:47, 19.22s/it]

125


126it [41:06, 19.35s/it]

126


127it [41:26, 19.37s/it]

127


128it [41:46, 19.57s/it]

128


129it [42:05, 19.43s/it]

129


130it [42:24, 19.45s/it]

130


131it [42:46, 20.12s/it]

131


132it [43:06, 20.04s/it]

132


133it [43:25, 19.86s/it]

133


134it [43:45, 19.88s/it]

134


135it [44:04, 19.62s/it]

135


136it [44:24, 19.75s/it]

136


137it [44:44, 19.72s/it]

137


138it [45:02, 19.20s/it]

138


139it [45:22, 19.32s/it]

139


140it [45:41, 19.27s/it]

140


141it [46:00, 19.37s/it]

141


142it [46:20, 19.55s/it]

142


143it [46:40, 19.54s/it]

143


144it [46:59, 19.42s/it]

144


145it [47:18, 19.30s/it]

145


146it [47:37, 19.32s/it]

146


147it [47:57, 19.25s/it]

147


148it [48:16, 19.39s/it]

148


149it [48:36, 19.56s/it]

149


150it [48:56, 19.50s/it]

150


151it [49:16, 19.73s/it]

151


152it [49:35, 19.61s/it]

152


153it [49:54, 19.44s/it]

153


154it [50:14, 19.48s/it]

154


155it [50:32, 19.23s/it]

155


156it [50:52, 19.27s/it]

156


157it [51:11, 19.22s/it]

157


158it [51:30, 19.18s/it]

158


159it [51:49, 19.15s/it]

159


160it [52:12, 20.39s/it]

160


161it [52:31, 19.78s/it]

161


162it [52:50, 19.71s/it]

162


163it [53:10, 19.59s/it]

163


164it [53:30, 19.96s/it]

164


165it [53:50, 19.82s/it]

165


166it [54:09, 19.57s/it]

166


167it [54:28, 19.44s/it]

167


168it [54:48, 19.52s/it]

168


169it [55:07, 19.52s/it]

169


170it [55:26, 19.22s/it]

170


171it [55:45, 19.12s/it]

171


172it [56:04, 19.34s/it]

172


173it [56:24, 19.26s/it]

173


174it [56:42, 19.02s/it]

174


175it [57:02, 19.27s/it]

175


176it [57:21, 19.16s/it]

176


177it [57:41, 19.38s/it]

177


178it [57:59, 19.12s/it]

178


179it [58:21, 19.81s/it]

179


180it [58:41, 19.88s/it]

180


181it [59:02, 20.24s/it]

181


182it [59:21, 19.95s/it]

182


183it [59:40, 19.71s/it]

183


184it [1:00:00, 19.78s/it]

184


185it [1:00:19, 19.53s/it]

185


186it [1:00:38, 19.25s/it]

186


187it [1:00:57, 19.24s/it]

187


188it [1:01:16, 19.33s/it]

188


189it [1:01:35, 19.13s/it]

189


190it [1:01:54, 19.11s/it]

190


191it [1:02:13, 19.08s/it]

191


192it [1:02:33, 19.17s/it]

192


193it [1:02:51, 19.08s/it]

193


194it [1:03:10, 18.93s/it]

194


195it [1:03:29, 19.02s/it]

195


196it [1:03:48, 18.99s/it]

196


197it [1:04:07, 18.95s/it]

197


198it [1:04:26, 19.05s/it]

198


199it [1:04:46, 19.32s/it]

199


200it [1:05:05, 19.10s/it]

200


201it [1:05:24, 19.21s/it]

201


202it [1:05:43, 19.21s/it]

202


203it [1:06:04, 19.48s/it]

203


204it [1:06:25, 19.93s/it]

204


205it [1:06:45, 19.95s/it]

205


206it [1:07:03, 19.61s/it]

206


207it [1:07:24, 19.81s/it]

207


208it [1:07:44, 19.93s/it]

208


209it [1:08:03, 19.63s/it]

209


210it [1:08:22, 19.64s/it]

210


211it [1:08:41, 19.35s/it]

211


212it [1:08:59, 18.94s/it]

212


213it [1:09:21, 19.96s/it]

213


214it [1:09:43, 20.37s/it]

214


215it [1:10:02, 19.98s/it]

215


216it [1:10:21, 19.86s/it]

216


217it [1:10:41, 19.90s/it]

217


218it [1:11:01, 19.88s/it]

218


219it [1:11:21, 19.97s/it]

219


220it [1:11:42, 20.01s/it]

220


221it [1:12:02, 20.24s/it]

221


222it [1:12:21, 19.80s/it]

222


223it [1:12:42, 20.21s/it]

223


224it [1:13:02, 20.03s/it]

224


225it [1:13:22, 20.07s/it]

225


226it [1:13:41, 19.72s/it]

226


227it [1:14:01, 19.87s/it]

227


228it [1:14:20, 19.54s/it]

228


229it [1:14:39, 19.51s/it]

229


230it [1:14:58, 19.33s/it]

230


231it [1:15:18, 19.32s/it]

231


232it [1:15:38, 19.56s/it]

232


233it [1:15:56, 19.27s/it]

233


234it [1:16:17, 19.58s/it]

234


235it [1:16:35, 19.15s/it]

235


236it [1:16:53, 18.87s/it]

236


237it [1:17:12, 18.94s/it]

237


238it [1:17:31, 19.02s/it]

238


239it [1:17:49, 18.78s/it]

239


240it [1:18:08, 18.79s/it]

240


241it [1:18:28, 18.95s/it]

241


242it [1:18:47, 19.01s/it]

242


243it [1:19:06, 18.98s/it]

243


244it [1:19:25, 19.16s/it]

244


245it [1:19:45, 19.26s/it]

245


246it [1:20:04, 19.26s/it]

246


247it [1:20:22, 19.02s/it]

247


248it [1:20:41, 19.01s/it]

248


249it [1:21:00, 19.01s/it]

249


250it [1:21:20, 19.10s/it]

250


251it [1:21:40, 19.32s/it]

251


252it [1:21:59, 19.25s/it]

252


253it [1:22:18, 19.24s/it]

253


254it [1:22:36, 18.96s/it]

254


255it [1:22:56, 19.26s/it]

255


256it [1:23:15, 19.19s/it]

256


257it [1:23:34, 19.03s/it]

257


258it [1:23:54, 19.23s/it]

258


259it [1:24:12, 19.11s/it]

259


260it [1:24:32, 19.35s/it]

260


261it [1:24:52, 19.34s/it]

261


262it [1:25:11, 19.48s/it]

262


263it [1:25:30, 19.28s/it]

263


264it [1:25:49, 19.23s/it]

264


265it [1:26:08, 19.01s/it]

265


266it [1:26:27, 18.97s/it]

266


267it [1:26:47, 19.23s/it]

267


268it [1:27:06, 19.19s/it]

268


269it [1:27:25, 19.10s/it]

269


270it [1:27:44, 19.06s/it]

270


271it [1:28:03, 19.24s/it]

271


272it [1:28:23, 19.54s/it]

272


273it [1:28:43, 19.46s/it]

273


274it [1:29:01, 19.16s/it]

274


275it [1:29:21, 19.34s/it]

275


276it [1:29:40, 19.29s/it]

276


277it [1:30:00, 19.51s/it]

277


278it [1:30:19, 19.27s/it]

278


279it [1:30:38, 19.36s/it]

279


280it [1:30:58, 19.39s/it]

280


281it [1:31:17, 19.27s/it]

281


282it [1:31:37, 19.49s/it]

282


283it [1:31:57, 19.54s/it]

283


284it [1:32:16, 19.39s/it]

284


285it [1:32:35, 19.48s/it]

285


286it [1:32:54, 19.35s/it]

286


287it [1:33:13, 19.22s/it]

287


288it [1:33:33, 19.38s/it]

288


289it [1:33:52, 19.28s/it]

289


290it [1:34:11, 19.22s/it]

290


291it [1:34:31, 19.39s/it]

291


292it [1:34:50, 19.36s/it]

292


293it [1:35:10, 19.54s/it]

293


294it [1:35:29, 19.40s/it]

294


295it [1:35:48, 19.35s/it]

295


296it [1:36:07, 19.15s/it]

296


297it [1:36:27, 19.31s/it]

297


298it [1:36:46, 19.23s/it]

298


299it [1:37:05, 19.19s/it]

299


300it [1:37:24, 19.19s/it]

300


301it [1:37:43, 19.06s/it]

301


302it [1:38:02, 18.98s/it]

302


303it [1:38:21, 19.01s/it]

303


304it [1:38:39, 18.87s/it]

304


305it [1:38:59, 19.09s/it]

305


306it [1:39:17, 18.77s/it]

306


307it [1:39:37, 19.03s/it]

307


308it [1:39:56, 19.21s/it]

308


309it [1:40:15, 19.01s/it]

309


310it [1:40:34, 19.05s/it]

310


311it [1:40:53, 19.17s/it]

311


312it [1:41:12, 19.10s/it]

312


313it [1:41:32, 19.35s/it]

313


314it [1:41:50, 19.03s/it]

314


315it [1:42:10, 19.29s/it]

315


316it [1:42:30, 19.49s/it]

316


317it [1:42:49, 19.27s/it]

317


318it [1:43:08, 19.15s/it]

318


319it [1:43:27, 19.08s/it]

319


320it [1:43:46, 19.14s/it]

320


321it [1:44:05, 18.95s/it]

321


322it [1:44:24, 19.17s/it]

322


323it [1:44:44, 19.42s/it]

323


324it [1:45:03, 19.09s/it]

324


325it [1:45:22, 19.03s/it]

325


326it [1:45:40, 18.99s/it]

326


327it [1:45:59, 18.88s/it]

327


328it [1:46:18, 18.84s/it]

328


329it [1:46:37, 18.90s/it]

329


330it [1:46:55, 18.66s/it]

330


331it [1:47:14, 18.64s/it]

331


332it [1:47:33, 19.03s/it]

332


333it [1:47:53, 19.09s/it]

333


334it [1:48:11, 18.99s/it]

334


335it [1:48:31, 19.17s/it]

335


336it [1:48:50, 19.09s/it]

336


337it [1:49:09, 19.15s/it]

337


338it [1:49:31, 20.02s/it]

338


339it [1:49:54, 20.87s/it]

339


340it [1:50:17, 21.35s/it]

340


341it [1:50:40, 21.84s/it]

341


342it [1:51:00, 21.36s/it]

342


343it [1:51:20, 20.87s/it]

343


344it [1:51:39, 20.40s/it]

344


345it [1:51:59, 20.31s/it]

345


346it [1:52:22, 21.01s/it]

346


347it [1:52:44, 21.52s/it]

347


348it [1:53:09, 22.60s/it]

348


349it [1:53:29, 21.74s/it]

349


350it [1:53:49, 21.09s/it]

350


351it [1:54:08, 20.52s/it]

351


352it [1:54:27, 20.19s/it]

352


353it [1:54:46, 19.73s/it]

353


354it [1:55:05, 19.48s/it]

354


355it [1:55:23, 19.21s/it]

355


356it [1:55:43, 19.17s/it]

356


357it [1:56:01, 19.02s/it]

357


358it [1:56:20, 18.94s/it]

358


359it [1:56:39, 19.01s/it]

359


360it [1:56:58, 18.95s/it]

360


361it [1:57:17, 18.88s/it]

361


362it [1:57:35, 18.85s/it]

362


363it [1:57:55, 18.93s/it]

363


364it [1:58:13, 18.86s/it]

364


365it [1:58:31, 18.64s/it]

365


366it [1:58:51, 18.91s/it]

366


367it [1:59:11, 19.21s/it]

367


368it [1:59:30, 19.23s/it]

368


369it [1:59:49, 19.24s/it]

369


370it [2:00:09, 19.42s/it]

370


371it [2:00:28, 19.32s/it]

371


372it [2:00:48, 19.38s/it]

372


373it [2:01:10, 20.15s/it]

373


374it [2:01:29, 19.87s/it]

374


375it [2:01:48, 19.70s/it]

375


376it [2:02:08, 19.74s/it]

376


377it [2:02:28, 19.82s/it]

377


378it [2:02:47, 19.50s/it]

378


379it [2:03:07, 19.61s/it]

379


380it [2:03:25, 19.26s/it]

380


381it [2:03:44, 19.14s/it]

381


382it [2:04:03, 19.15s/it]

382


383it [2:04:22, 19.05s/it]

383


384it [2:04:41, 19.06s/it]

384


385it [2:05:00, 19.13s/it]

385


386it [2:05:19, 19.02s/it]

386


387it [2:05:38, 19.06s/it]

387


388it [2:05:58, 19.22s/it]

388


389it [2:06:17, 19.16s/it]

389


390it [2:06:36, 19.25s/it]

390


391it [2:06:56, 19.30s/it]

391


392it [2:07:15, 19.16s/it]

392


393it [2:07:34, 19.33s/it]

393


394it [2:07:53, 19.11s/it]

394


395it [2:08:13, 19.24s/it]

395


396it [2:08:32, 19.43s/it]

396


397it [2:08:51, 19.22s/it]

397


398it [2:09:11, 19.27s/it]

398


399it [2:09:31, 19.63s/it]

399


400it [2:09:50, 19.42s/it]

400


401it [2:10:10, 19.62s/it]

401


402it [2:10:29, 19.51s/it]

402


403it [2:10:49, 19.58s/it]

403


404it [2:11:08, 19.39s/it]

404


405it [2:11:27, 19.23s/it]

405


406it [2:11:46, 19.13s/it]

406


407it [2:12:05, 19.32s/it]

407


408it [2:12:25, 19.40s/it]

408


409it [2:12:44, 19.29s/it]

409


410it [2:13:03, 19.05s/it]

410


411it [2:13:21, 18.94s/it]

411


412it [2:13:41, 19.06s/it]

412


413it [2:14:00, 19.27s/it]

413


414it [2:14:19, 19.22s/it]

414


415it [2:14:38, 19.13s/it]

415


416it [2:14:58, 19.19s/it]

416


417it [2:15:17, 19.19s/it]

417


418it [2:15:35, 18.95s/it]

418


419it [2:15:54, 18.82s/it]

419


420it [2:16:13, 18.97s/it]

420


421it [2:16:32, 18.99s/it]

421


422it [2:16:51, 18.93s/it]

422


423it [2:17:10, 19.05s/it]

423


424it [2:17:29, 18.87s/it]

424


425it [2:17:49, 19.18s/it]

425


426it [2:18:07, 18.86s/it]

426


427it [2:18:25, 18.82s/it]

427


428it [2:18:44, 18.77s/it]

428


429it [2:19:03, 18.77s/it]

429


430it [2:19:22, 19.00s/it]

430


431it [2:19:44, 19.88s/it]

431


432it [2:20:04, 19.93s/it]

432


433it [2:20:24, 19.79s/it]

433


434it [2:20:43, 19.73s/it]

434


435it [2:21:03, 19.66s/it]

435


436it [2:21:22, 19.45s/it]

436


437it [2:21:41, 19.39s/it]

437


438it [2:22:01, 19.41s/it]

438


439it [2:22:20, 19.26s/it]

439


440it [2:22:39, 19.23s/it]

440


441it [2:22:58, 19.20s/it]

441


442it [2:23:17, 19.05s/it]

442


443it [2:23:36, 19.10s/it]

443


444it [2:23:54, 18.97s/it]

444


445it [2:24:14, 19.06s/it]

445


446it [2:24:33, 19.01s/it]

446


447it [2:24:51, 18.78s/it]

447


448it [2:25:10, 18.86s/it]

448


449it [2:25:29, 18.99s/it]

449


450it [2:25:47, 18.55s/it]

450


451it [2:26:05, 18.47s/it]

451


452it [2:26:24, 18.56s/it]

452


453it [2:26:43, 18.83s/it]

453


454it [2:27:02, 18.91s/it]

454


455it [2:27:21, 18.85s/it]

455


456it [2:27:40, 18.95s/it]

456


457it [2:27:59, 18.96s/it]

457


458it [2:28:18, 18.94s/it]

458


459it [2:28:37, 19.03s/it]

459


460it [2:28:57, 19.16s/it]

460


461it [2:29:16, 19.04s/it]

461


462it [2:29:35, 19.07s/it]

462


463it [2:29:54, 19.02s/it]

463


464it [2:30:12, 18.90s/it]

464


465it [2:30:31, 18.97s/it]

465


466it [2:30:50, 18.99s/it]

466


467it [2:31:09, 18.96s/it]

467


468it [2:31:29, 19.17s/it]

468


469it [2:31:48, 19.21s/it]

469


470it [2:32:07, 19.13s/it]

470


471it [2:32:27, 19.29s/it]

471


472it [2:32:47, 19.42s/it]

472


473it [2:33:07, 19.66s/it]

473


474it [2:33:26, 19.56s/it]

474


475it [2:33:45, 19.33s/it]

475


476it [2:34:05, 19.70s/it]

476


477it [2:34:25, 19.78s/it]

477


478it [2:34:44, 19.47s/it]

478


479it [2:35:03, 19.21s/it]

479


480it [2:35:21, 19.00s/it]

480


481it [2:35:41, 19.28s/it]

481


482it [2:36:01, 19.30s/it]

482


483it [2:36:19, 18.98s/it]

483


484it [2:36:38, 19.01s/it]

484


485it [2:36:56, 18.73s/it]

485


486it [2:37:15, 18.86s/it]

486


487it [2:37:34, 18.74s/it]

487


488it [2:37:53, 18.94s/it]

488


489it [2:38:11, 18.81s/it]

489


490it [2:38:31, 18.93s/it]

490


491it [2:38:50, 18.93s/it]

491


492it [2:39:08, 18.68s/it]

492


493it [2:39:29, 19.42s/it]

493


494it [2:39:47, 19.14s/it]

494


495it [2:40:06, 18.96s/it]

495


496it [2:40:25, 19.06s/it]

496


497it [2:40:45, 19.21s/it]

497


498it [2:41:05, 19.39s/it]

498


499it [2:41:24, 19.42s/it]

499


500it [2:41:43, 19.37s/it]

500


501it [2:42:03, 19.38s/it]

501


502it [2:42:22, 19.36s/it]

502


503it [2:42:45, 20.38s/it]

503


504it [2:43:07, 20.93s/it]

504


505it [2:43:30, 21.49s/it]

505


506it [2:43:52, 21.69s/it]

506


507it [2:44:14, 21.92s/it]

507


508it [2:44:36, 21.87s/it]

508


509it [2:44:59, 22.08s/it]

509


510it [2:45:22, 22.29s/it]

510


511it [2:45:45, 22.50s/it]

511


512it [2:46:07, 22.51s/it]

512


513it [2:46:29, 22.43s/it]

513


514it [2:46:52, 22.58s/it]

514


515it [2:47:14, 22.43s/it]

515


516it [2:47:37, 22.42s/it]

516


517it [2:48:00, 22.70s/it]

517


518it [2:48:23, 22.72s/it]

518


519it [2:48:45, 22.70s/it]

519


520it [2:49:08, 22.63s/it]

520


521it [2:49:29, 22.30s/it]

521


522it [2:49:52, 22.41s/it]

522


523it [2:50:14, 22.14s/it]

523


524it [2:50:36, 22.28s/it]

524


525it [2:50:58, 22.21s/it]

525


526it [2:51:18, 21.38s/it]

526


527it [2:51:37, 20.88s/it]

527


528it [2:51:57, 20.37s/it]

528


529it [2:52:16, 20.12s/it]

529


530it [2:52:35, 19.81s/it]

530


531it [2:52:53, 19.33s/it]

531


532it [2:53:12, 19.00s/it]

532


533it [2:53:30, 18.90s/it]

533


534it [2:53:49, 18.77s/it]

534


535it [2:54:07, 18.58s/it]

535


536it [2:54:25, 18.50s/it]

536


537it [2:54:44, 18.53s/it]

537


538it [2:55:04, 18.95s/it]

538


539it [2:55:22, 18.77s/it]

539


540it [2:55:40, 18.56s/it]

540


541it [2:56:00, 18.78s/it]

541


542it [2:56:19, 18.94s/it]

542


543it [2:56:38, 19.12s/it]

543


544it [2:56:58, 19.16s/it]

544


545it [2:57:16, 18.89s/it]

545


546it [2:57:36, 19.35s/it]

546


547it [2:57:55, 19.27s/it]

547


548it [2:58:16, 19.60s/it]

548


549it [2:58:35, 19.45s/it]

549


550it [2:58:54, 19.35s/it]

550


551it [2:59:13, 19.22s/it]

551


552it [2:59:31, 19.03s/it]

552


553it [2:59:51, 19.30s/it]

553


554it [3:00:10, 18.98s/it]

554


555it [3:00:28, 18.88s/it]

555


556it [3:00:47, 18.97s/it]

556


557it [3:01:06, 18.98s/it]

557


558it [3:01:26, 19.10s/it]

558


559it [3:01:45, 19.12s/it]

559


560it [3:02:04, 19.11s/it]

560


561it [3:02:23, 19.06s/it]

561


562it [3:02:43, 19.31s/it]

562


563it [3:03:02, 19.18s/it]

563


564it [3:03:21, 19.22s/it]

564


565it [3:03:40, 19.12s/it]

565


566it [3:03:59, 19.19s/it]

566


567it [3:04:19, 19.18s/it]

567


568it [3:04:38, 19.16s/it]

568


569it [3:04:58, 19.51s/it]

569


570it [3:05:17, 19.38s/it]

570


571it [3:05:36, 19.37s/it]

571


572it [3:05:55, 19.12s/it]

572


573it [3:06:16, 19.63s/it]

573


574it [3:06:35, 19.51s/it]

574


575it [3:06:55, 19.65s/it]

575


576it [3:07:14, 19.50s/it]

576


577it [3:07:33, 19.29s/it]

577


578it [3:07:52, 19.26s/it]

578


579it [3:08:10, 18.97s/it]

579


580it [3:08:30, 19.07s/it]

580


581it [3:08:48, 19.00s/it]

581


582it [3:09:08, 19.09s/it]

582


583it [3:09:28, 19.34s/it]

583


584it [3:09:47, 19.38s/it]

584


585it [3:10:06, 19.28s/it]

585


586it [3:10:25, 19.25s/it]

586


587it [3:10:44, 19.19s/it]

587


588it [3:11:04, 19.38s/it]

588


589it [3:11:24, 19.40s/it]

589


590it [3:11:43, 19.32s/it]

590


591it [3:12:02, 19.14s/it]

591


592it [3:12:20, 18.89s/it]

592


593it [3:12:39, 18.99s/it]

593


594it [3:12:59, 19.26s/it]

594


595it [3:13:18, 19.24s/it]

595


596it [3:13:38, 19.32s/it]

596


597it [3:13:57, 19.19s/it]

597


598it [3:14:15, 18.93s/it]

598


599it [3:14:34, 18.93s/it]

599


600it [3:14:53, 18.98s/it]

600


601it [3:15:12, 18.88s/it]

601


602it [3:15:31, 19.04s/it]

602


603it [3:15:50, 18.91s/it]

603


604it [3:16:09, 18.98s/it]

604


605it [3:16:28, 18.92s/it]

605


606it [3:16:46, 18.74s/it]

606


607it [3:17:05, 18.76s/it]

607


608it [3:17:24, 18.78s/it]

608


609it [3:17:43, 18.93s/it]

609


610it [3:18:02, 19.03s/it]

610


611it [3:18:21, 19.00s/it]

611


612it [3:18:40, 19.06s/it]

612


613it [3:18:59, 18.96s/it]

613


614it [3:19:18, 19.14s/it]

614


615it [3:19:37, 19.05s/it]

615


616it [3:19:56, 18.98s/it]

616


617it [3:20:15, 18.99s/it]

617


618it [3:20:34, 19.00s/it]

618


619it [3:20:53, 19.06s/it]

619


620it [3:21:13, 19.11s/it]

620


621it [3:21:31, 18.94s/it]

621


622it [3:21:50, 19.01s/it]

622


623it [3:22:10, 19.26s/it]

623


624it [3:22:29, 19.27s/it]

624


625it [3:22:48, 19.07s/it]

625


626it [3:23:07, 19.11s/it]

626


627it [3:23:26, 19.01s/it]

627


628it [3:23:46, 19.21s/it]

628


629it [3:24:06, 19.46s/it]

629


630it [3:24:25, 19.46s/it]

630


631it [3:24:44, 19.23s/it]

631


632it [3:25:03, 19.24s/it]

632


633it [3:25:22, 19.18s/it]

633


634it [3:25:42, 19.23s/it]

634


635it [3:26:01, 19.29s/it]

635


636it [3:26:21, 19.51s/it]

636


637it [3:26:40, 19.36s/it]

637


638it [3:26:58, 19.08s/it]

638


639it [3:27:18, 19.14s/it]

639


640it [3:27:37, 19.11s/it]

640


641it [3:27:56, 19.27s/it]

641


642it [3:28:15, 18.96s/it]

642


643it [3:28:34, 19.22s/it]

643


644it [3:28:54, 19.34s/it]

644


645it [3:29:12, 18.91s/it]

645


646it [3:29:31, 19.08s/it]

646


647it [3:29:51, 19.22s/it]

647


648it [3:30:10, 19.26s/it]

648


649it [3:30:30, 19.32s/it]

649


650it [3:30:50, 19.50s/it]

650


651it [3:31:08, 19.05s/it]

651


652it [3:31:31, 20.45s/it]

652


653it [3:31:54, 21.13s/it]

653


654it [3:32:17, 21.78s/it]

654


655it [3:32:39, 21.64s/it]

655


656it [3:33:01, 21.77s/it]

656


657it [3:33:23, 21.87s/it]

657


658it [3:33:45, 21.83s/it]

658


659it [3:34:08, 22.14s/it]

659


660it [3:34:30, 22.26s/it]

660


661it [3:34:53, 22.32s/it]

661


662it [3:35:12, 21.36s/it]

662


663it [3:35:31, 20.65s/it]

663


664it [3:35:49, 20.05s/it]

664


665it [3:36:09, 19.89s/it]

665


666it [3:36:28, 19.71s/it]

666


667it [3:36:47, 19.40s/it]

667


668it [3:37:07, 19.59s/it]

668


669it [3:37:27, 19.86s/it]

669


670it [3:37:47, 19.68s/it]

670


671it [3:38:07, 19.79s/it]

671


672it [3:38:26, 19.70s/it]

672


673it [3:38:45, 19.41s/it]

673


674it [3:39:05, 19.52s/it]

674


675it [3:39:24, 19.56s/it]

675


676it [3:39:44, 19.47s/it]

676


677it [3:40:03, 19.47s/it]

677


678it [3:40:24, 19.95s/it]

678


679it [3:40:44, 19.85s/it]

679


680it [3:41:03, 19.76s/it]

680


681it [3:41:23, 19.75s/it]

681


682it [3:41:42, 19.66s/it]

682


683it [3:42:02, 19.51s/it]

683


684it [3:42:21, 19.43s/it]

684


685it [3:42:40, 19.29s/it]

685


686it [3:42:58, 18.98s/it]

686


687it [3:43:17, 18.88s/it]

687


688it [3:43:36, 18.99s/it]

688


689it [3:43:56, 19.22s/it]

689


690it [3:44:15, 19.17s/it]

690


691it [3:44:34, 19.13s/it]

691


692it [3:44:53, 19.12s/it]

692


693it [3:45:11, 18.95s/it]

693


694it [3:45:30, 18.83s/it]

694


695it [3:45:49, 18.90s/it]

695


696it [3:46:09, 19.29s/it]

696


697it [3:46:29, 19.34s/it]

697


698it [3:46:49, 19.48s/it]

698


699it [3:47:08, 19.50s/it]

699


700it [3:47:27, 19.43s/it]

700


701it [3:47:46, 19.30s/it]

701


702it [3:48:05, 19.25s/it]

702


703it [3:48:25, 19.19s/it]

703


704it [3:48:44, 19.37s/it]

704


705it [3:49:02, 18.95s/it]

705


706it [3:49:22, 19.20s/it]

706


707it [3:49:41, 19.09s/it]

707


708it [3:50:00, 19.23s/it]

708


709it [3:50:19, 19.15s/it]

709


710it [3:50:39, 19.40s/it]

710


711it [3:50:58, 19.24s/it]

711


712it [3:51:17, 19.22s/it]

712


713it [3:51:38, 19.48s/it]

713


714it [3:51:57, 19.48s/it]

714


715it [3:52:16, 19.35s/it]

715


716it [3:52:34, 19.06s/it]

716


717it [3:52:53, 18.99s/it]

717


718it [3:53:12, 19.05s/it]

718


719it [3:53:32, 19.10s/it]

719


720it [3:53:50, 18.98s/it]

720


721it [3:54:09, 18.77s/it]

721


722it [3:54:28, 18.88s/it]

722


In [None]:
torch.cuda.get_device_name(0)

In [None]:
bs                 = 16
game_data          = dict()

In [None]:
game_data 

In [None]:
dataset.set_format("pandas")

In [None]:
df_batch           = dataset[:].sample(bs)
df_batch 

In [None]:
game_data["query"] = df_batch["query"].tolist()
query_tensors      = df_batch["input_ids"].tolist()

In [None]:
response_tensors_ref, response_tensors = [], []

In [None]:
gen_kwargs = {
         "min_length":   -1, 
         "top_k":       0.0, 
         "top_p":       1.0, 
         "do_sample":  True, 
         "pad_token_id": tokenizer.eos_token_id
}

In [None]:
for i in range(bs):
    gen_len = output_length_sampler()
    
    output  = ref_model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors_ref.append(output)
    
    
    output = model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors.append(output)

In [None]:
game_data["response (before)"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
game_data["response (after)"]  = [tokenizer.decode(response_tensors[i]) for i in range(bs)]

In [None]:
texts = [q + r for q, r in zip(game_data["query"], game_data["response (before)"])]
game_data["rewards (before)"] = [output[1]["score"] for output in pipe(texts, **sent_kwargs)]

In [None]:
texts = [q + r for q, r in zip(game_data["query"], game_data["response (after)"])]
game_data["rewards (after)"] = [output[1]["score"] for output in pipe(texts, **sent_kwargs)]

In [None]:
df_results = pd.DataFrame(game_data)
df_results

In [None]:
!nvidia-smi