In [None]:
pip install --upgrade transformers huggingface_hub --q

In [2]:
import os 
import json 
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
from peft import LoraConfig, TaskType, get_peft_model

In [3]:
import torch

In [4]:
model_name="EleutherAI/pythia-410m"

In [5]:
base_model = AutoModelForCausalLM.from_pretrained(model_name,attn_implementation="sdpa",dtype=torch.float16, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_name)

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/911M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

In [6]:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.padding_side = 'left'

In [7]:
from peft import PeftModel

In [8]:
policy_model = PeftModel.from_pretrained(
    base_model,
    "/kaggle/input/loraadapters/pytorch/default/1",
    trainable=True
)

In [9]:
# config = LoraConfig(
#     r=8,
#     lora_alpha=16,
#     lora_dropout=0.1,
#     bias='none',
#     target_modules=["query_key_value"],
#     task_type=TaskType.CAUSAL_LM
# )
# lora_model = get_peft_model(model, config)
# optimizer = torch.optim.AdamW(lora_model.parameters(), lr=1e-4)

# lora_model, optimizer = accelerator.prepare(lora_model, optimizer)

# accelerator.load_state("/kaggle/input/bestmodel/pytorch/default/1")

In [10]:
# lora_model.save_pretrained("/kaggle/working/lora_adapters")

In [11]:
value_model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=1
)

Some weights of GPTNeoXForSequenceClassification were not initialized from the model checkpoint at EleutherAI/pythia-410m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
value_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["query_key_value"],
    lora_dropout=0.05,
    bias='none',
    task_type=TaskType.SEQ_CLS
)

In [13]:
value_model = get_peft_model(value_model, value_config)
value_model.to('cuda')
value_model.config.pad_token_id = tokenizer.pad_token_id

In [14]:
value_model.print_trainable_parameters()

trainable params: 1,573,888 || all params: 355,397,632 || trainable%: 0.4429


In [15]:
value_model.device


device(type='cuda', index=0)

In [16]:
ref_model = PeftModel.from_pretrained(
    base_model,
    "/kaggle/input/loraadapters/pytorch/default/1",
    trainable=False
)



In [17]:
with open("/kaggle/input/mathdataset/math_dataset.json", 'r') as f:
    dataset = json.load(f)

In [18]:
from datasets import Dataset

In [19]:
train_data = []
for item in dataset['train']:
    train_data.append({
        'question': item['question'],
        'answer': item['answer']
    })

train_dataset = Dataset.from_list(train_data)

In [20]:
def tokenize_function(sample):
    tokenized = tokenizer(
        sample['question'], 
        truncation=True, 
        padding='max_length',
        max_length=256
    )
    return {
        'input_ids': tokenized['input_ids'],
        'attention_mask': tokenized['attention_mask'],
        'answer': sample['answer'],  # Keep this!
        'question': sample['question']  # Keep this too!
    }

tokenized_dataset = train_dataset.map(
    tokenize_function,
    remove_columns=['question', 'answer']  # Remove only after copying
)

Map:   0%|          | 0/738 [00:00<?, ? examples/s]

In [21]:
tokenized_dataset

Dataset({
    features: ['question', 'answer', 'input_ids', 'attention_mask'],
    num_rows: 738
})

In [22]:
from typing import Callable, Optional, Tuple, List
from transformers import PreTrainedTokenizerBase
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [23]:
def compute_reward(response: torch.Tensor, answers: List)->torch.Tensor:
    rewards = []
    for response, answer in zip(response, answers):
        response_ = response.lower().strip()
        answer_ = answer.lower().strip()

        if answer_ in response_:
            reward = 1.0
        else:
            reward= -1.0
        rewards.append(reward)
    return torch.tensor(rewards, dtype=torch.float16)

In [24]:
def compute_gae(reward: torch.Tensor, values:torch.Tensor, gamma: float=0.99, lam: float=0.95)->Tuple[torch.Tensor, torch.Tensor]:
    if values.dim() == 1:
        advantages = reward - values  # (batch_size,)
        returns = reward  # (batch_size,)
        return advantages.unsqueeze(-1), returns.unsqueeze(-1)
 
    batch_size, seq_len = values.shape
    rewards = torch.zeros_like(values)
    rewards[:, -1] = reward

    next_values = torch.zeros_like(values)
    next_values[:, :-1] = values[:,1:]
    deltas = rewards+gamma * next_values - values
    advantages = torch.zeros_like(values)
    gae = torch.zeros(batch_size).to(values.device)
    for t in reversed(range(seq_len)):
        gae = deltas[:, t]+gamma*lam*gae
        advantages[:, t] = gae

    returns = advantages + values 
    return advantages, returns

In [25]:
def get_log_probs(model, input_ids, attention_mask):
    """Get log probabilities for generated tokens"""
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits[:, :-1, :]
    log_probs = F.log_softmax(logits, dim=-1)
    
    next_tokens = input_ids[:, 1:]
    token_log_probs = torch.gather(log_probs, 2, next_tokens.unsqueeze(-1)).squeeze(-1)
    
    mask = attention_mask[:, 1:].float()
    token_log_probs = token_log_probs * mask
    
    return token_log_probs, mask

In [26]:
for n, p in policy_model.named_parameters():
    if 'lora' in n:
        p.requires_grad = True

In [27]:
policy_model.print_trainable_parameters()

trainable params: 786,432 || all params: 406,120,448 || trainable%: 0.1936


In [28]:
def compute_entropy(logits, mask):
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -(probs * log_probs).sum(dim=-1)
    return (entropy * mask).sum(dim=-1) / mask.sum(dim=-1)

In [29]:
def compute_reward(responses: List[str], answers: List[str]) -> torch.Tensor:
    """Better reward with partial credit"""
    from difflib import SequenceMatcher
    
    rewards = []
    for response, answer in zip(responses, answers):
        response_clean = response.lower().strip()
        answer_clean = answer.lower().strip()
        
        # Exact match in response
        if answer_clean in response_clean:
            reward = 1.0
        
        # Partial match - string similarity
        elif SequenceMatcher(None, answer_clean, response_clean).ratio() > 0.6:
            reward = 0.5
        
        # Answer words present
        elif any(word in response_clean for word in answer_clean.split() if len(word) > 2):
            reward = 0.2
        
        # At least it tried (not empty)
        elif len(response_clean) > 10:
            reward = -0.3  # Small penalty, not harsh
        
        else:
            reward = -1.0
        response_len = len(response_clean.split())
        answer_len = len(answer_clean.split())

        if response_len > 0:
            len_ratio = answer_len / response_len
            len_ratio = max(0.0, min(len_ratio, 1.5))
            reward += 0.3 * len_ratio  # weight conciseness
            
        rewards.append(reward)
    
    return torch.tensor(rewards, dtype=torch.float16)

In [30]:
def ppo_trainer(
    policy_model: PeftModel, 
    value_model: PeftModel,
    ref_model: PeftModel,
    tokenizer: PreTrainedTokenizerBase,
    dataset: Dataset,
    policy_optimizer: torch.optim.AdamW,
    value_optimizer: torch.optim.AdamW,
    reward_model: Optional[Callable[[list[str], list[str]], list[float]]] = None,
    num_epochs=3,
    rollout_batch_size=32,
    mini_batch_size=8,
    ppo_epochs=4,
    clip_eps: float=0.1,
    kl_coef: float=0.01,
    value_coef: float=0.5,
    entropy_coef: float=0.01,
    gamma: float=0.99,
    lam: float=0.95,
    max_gen_len: int=128
):
    def collate_fn(batch):
        input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch])
        attention_mask = torch.stack([torch.tensor(item['attention_mask']) for item in batch])
        answers = [item['answer'] for item in batch]
        questions = [item['question'] for item in batch]
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'answers': answers,
            'questions': questions
        }
        
    dataloader = DataLoader(dataset, batch_size=rollout_batch_size, collate_fn=collate_fn,shuffle=True)
    all_epoch_rewards = []
    

    for epoch in range(num_epochs):
        print('-----------------------------')
        print(f"Epoch {epoch+1}/{num_epochs}")
        epoch_rewards = []
        for b_idx, batch in enumerate(tqdm(dataloader, desc='Rollouts')):
            policy_model.eval()
            value_model.eval()
            ref_model.eval()
            
            rollout_buffer=[]
            with torch.no_grad():
                input_ids = batch['input_ids'].to(policy_model.device)
                attention_mask = batch['attention_mask'].to(policy_model.device)
                answers = batch['answers']
                
                response_token = policy_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_gen_len,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=tokenizer.pad_token_id
                )
                responses=tokenizer.batch_decode(response_token, skip_special_tokens=True)
                
                rewards = compute_reward(responses, answers)
                rewards = rewards.to(policy_model.device)
                batch_avg_reward = rewards.mean().item()
                epoch_rewards.append(batch_avg_reward)

                attention_mask = (response_token != tokenizer.pad_token_id).long()

                values = value_model(
                    input_ids=response_token,
                    attention_mask=attention_mask
                ).logits.squeeze(-1)

                old_log_probs, mask = get_log_probs(
                    policy_model, response_token, attention_mask
                )
                ref_log_probs,_ = get_log_probs(
                    ref_model, response_token, attention_mask
                )
                advantages, returns = compute_gae(rewards, values, gamma, lam)
                advantages = (advantages-advantages.mean()) / (advantages.std()+1e-8)

                for i in range(len(batch['answers'])):
                    rollout_buffer.append({
                        'input_ids': response_token[i].cpu(),
                        'attention_mask': attention_mask[i].cpu(),
                        'old_log_probs': old_log_probs[i].cpu(),
                        'ref_log_probs': ref_log_probs[i].cpu(),
                        'advantages': advantages[i].cpu(),
                        'returns': returns[i].cpu(),
                        'mask': mask[i].cpu()
                    })
            policy_model.train()
            value_model.train()

            for ppo_epoch in range(ppo_epochs):
                import random
                random.shuffle(rollout_buffer)
                num_minibatches = len(rollout_buffer)//mini_batch_size

                ppo_policy_losses =[]
                ppo_value_losses =[]
                ppo_kl_penalties = []
                ppo_entropies = []
                for mb_idx in range(num_minibatches):
                    start_idx = mb_idx* mini_batch_size
                    end_idx = start_idx + mini_batch_size 
                    minibatch = rollout_buffer[start_idx:end_idx]
                    
                    #stack minibatches
                    mb_input_ids = torch.stack([item['input_ids']for item in minibatch])
                    mb_attention_mask = torch.stack([item['attention_mask'] for item in minibatch])
                    mb_old_log_probs = torch.stack([item['old_log_probs'] for item in minibatch])
                    mb_ref_log_probs = torch.stack([item['ref_log_probs'] for item in minibatch])
                    mb_advantages = torch.stack([item['advantages'] for item in minibatch])
                    mb_returns = torch.stack([item['returns'] for item in minibatch])
                    mb_mask = torch.stack([item['mask'] for item in minibatch])

                    #move to device
                    mb_input_ids = mb_input_ids.to(policy_model.device)
                    mb_attention_mask = mb_attention_mask.to(policy_model.device)
                    mb_advantages = mb_advantages.to(policy_model.device)
                    mb_old_log_probs = mb_old_log_probs.to(policy_model.device)
                    mb_ref_log_probs = mb_ref_log_probs.to(policy_model.device)
                    mb_returns = mb_returns.to(policy_model.device)
                    mb_mask = mb_mask.to(policy_model.device)
                    
                    policy_optimizer.zero_grad()
                    new_log_probs, _ = get_log_probs(
                        policy_model,
                        mb_input_ids,
                        mb_attention_mask
                    )
                    
                    #PPO loss
                    ratio = torch.exp(new_log_probs - mb_old_log_probs)    
                    surr1 = ratio*mb_advantages*mb_mask
                    surr2 = torch.clamp(ratio, 1-clip_eps, 1+clip_eps)* mb_advantages * mb_mask
                    policy_loss = -torch.min(surr1, surr2).sum() / mb_mask.sum()
                    #KL 
                    kl_div = (mb_ref_log_probs.exp() * (mb_ref_log_probs - new_log_probs)) * mb_mask
                    kl_penalty = kl_div.sum() / mb_mask.sum()
                    #entropy
                    outputs = policy_model(
                        input_ids=mb_input_ids,
                        attention_mask=mb_attention_mask
                    )
                    
                    entropy=compute_entropy(outputs.logits[:, :-1, :], mb_mask)

                    total_policy_loss = policy_loss + kl_coef * kl_penalty - entropy_coef * entropy.mean()
                    #backward
                    if not torch.isnan(total_policy_loss):
                        total_policy_loss.backward()
                        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_norm=1.0)
                        policy_optimizer.step()

                    #value update
                    value_optimizer.zero_grad()
                    new_values = value_model(
                        input_ids=mb_input_ids,
                        attention_mask=mb_attention_mask
                    ).logits.squeeze(-1)
                    
                    # if new_values.dim() == 1:
                    #     seq_len = mb_input_ids.shape[1]
                    #     new_values = new_values.unsqueeze(-1).expand(-1, seq_len)
                    final_returns = mb_returns[:, -1]
                    #value loss
                    value_loss = value_coef * ((new_values.squeeze() - mb_returns.squeeze()) ** 2).mean()

                    # Backward and optimize
                    if not torch.isnan(value_loss):
                        value_loss.backward()
                        torch.nn.utils.clip_grad_norm_(value_model.parameters(), max_norm=1.0)
                        value_optimizer.step()

                    ppo_policy_losses.append(policy_loss.item())
                    ppo_value_losses.append(value_loss.item())
                    ppo_kl_penalties.append(kl_penalty.item())
                    ppo_entropies.append(entropy.mean().item())
                
                if b_idx % 10 == 0 and len(ppo_policy_losses) > 0:
                    print(f"\n  Batch {b_idx} - PPO Epoch {ppo_epoch+1}/{ppo_epochs}")
                    print(f"    Policy Loss: {sum(ppo_policy_losses)/len(ppo_policy_losses):.4f}")
                    print(f"    Value Loss: {sum(ppo_value_losses)/len(ppo_value_losses):.4f}")
                    print(f"    KL Div: {sum(ppo_kl_penalties)/len(ppo_kl_penalties):.4f}")
                    print(f"    Entropy: {sum(ppo_entropies)/len(ppo_entropies):.4f}")
                    print(f"    Avg Reward: {batch_avg_reward:.4f}")
            
            del rollout_buffer
            torch.cuda.empty_cache()     
            
        epoch_avg_reward = sum(epoch_rewards) / len(epoch_rewards)
        all_epoch_rewards.append(epoch_avg_reward)
        print(f"\n{'='*50}")
        print(f"Epoch {epoch+1} Average Reward: {epoch_avg_reward:.4f}")
        print(f"{'='*50}\n")
    
    return policy_model, value_model

In [31]:
policy_optimizer = torch.optim.AdamW(policy_model.parameters(), lr=1e-4)
value_optimizer = torch.optim.AdamW(value_model.parameters(), lr=1e-4)

In [32]:
import os
os.environ['CUDA_LAUNCH_BLOCKING']='1'
os.environ["TORCH_USE_CUDA_DSA"] = '1'

In [33]:
trained_policy, trained_value = ppo_trainer(
    policy_model=policy_model,
    value_model=value_model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=tokenized_dataset,
    policy_optimizer=policy_optimizer,
    value_optimizer=value_optimizer,
    reward_model=None,
    num_epochs=8,  
    rollout_batch_size=16,  
    mini_batch_size=4,
    ppo_epochs=4,
    clip_eps=0.1,
    kl_coef=0.3,
    value_coef=0.5,
    entropy_coef=0.01,
    gamma=0.85,
    lam=0.87,
    max_gen_len=128 
)

-----------------------------
Epoch 1/8


Rollouts:   0%|          | 0/47 [00:00<?, ?it/s]


  Batch 0 - PPO Epoch 1/4
    Policy Loss: 0.1023
    Value Loss: 1.2231
    KL Div: 0.0027
    Entropy: 1.7889
    Avg Reward: -0.2134

  Batch 0 - PPO Epoch 2/4
    Policy Loss: 0.0887
    Value Loss: 0.5439
    KL Div: 0.0077
    Entropy: 1.8084
    Avg Reward: -0.2134

  Batch 0 - PPO Epoch 3/4
    Policy Loss: 0.0399
    Value Loss: 0.1944
    KL Div: 0.0112
    Entropy: 1.8183
    Avg Reward: -0.2134


Rollouts:   2%|▏         | 1/47 [00:24<18:53, 24.64s/it]


  Batch 0 - PPO Epoch 4/4
    Policy Loss: 0.0457
    Value Loss: 0.1589
    KL Div: 0.0097
    Entropy: 1.8161
    Avg Reward: -0.2134


Rollouts:  21%|██▏       | 10/47 [04:02<14:55, 24.20s/it]


  Batch 10 - PPO Epoch 1/4
    Policy Loss: 0.0094
    Value Loss: 0.5084
    KL Div: 0.0036
    Entropy: 1.3603
    Avg Reward: -0.2139

  Batch 10 - PPO Epoch 2/4
    Policy Loss: 0.0059
    Value Loss: 0.1620
    KL Div: 0.0049
    Entropy: 1.3548
    Avg Reward: -0.2139

  Batch 10 - PPO Epoch 3/4
    Policy Loss: -0.0007
    Value Loss: 0.0411
    KL Div: 0.0040
    Entropy: 1.3496
    Avg Reward: -0.2139


Rollouts:  23%|██▎       | 11/47 [04:26<14:31, 24.20s/it]


  Batch 10 - PPO Epoch 4/4
    Policy Loss: 0.0026
    Value Loss: 0.1277
    KL Div: 0.0046
    Entropy: 1.3409
    Avg Reward: -0.2139


Rollouts:  43%|████▎     | 20/47 [08:03<10:52, 24.18s/it]


  Batch 20 - PPO Epoch 1/4
    Policy Loss: 0.0114
    Value Loss: 0.3704
    KL Div: 0.0071
    Entropy: 1.1497
    Avg Reward: -0.2141

  Batch 20 - PPO Epoch 2/4
    Policy Loss: 0.0112
    Value Loss: 0.1183
    KL Div: 0.0066
    Entropy: 1.1522
    Avg Reward: -0.2141

  Batch 20 - PPO Epoch 3/4
    Policy Loss: 0.0097
    Value Loss: 0.0531
    KL Div: 0.0096
    Entropy: 1.1480
    Avg Reward: -0.2141


Rollouts:  45%|████▍     | 21/47 [08:28<10:29, 24.20s/it]


  Batch 20 - PPO Epoch 4/4
    Policy Loss: 0.0080
    Value Loss: 0.0732
    KL Div: 0.0121
    Entropy: 1.1588
    Avg Reward: -0.2141


Rollouts:  64%|██████▍   | 30/47 [12:05<06:51, 24.18s/it]


  Batch 30 - PPO Epoch 1/4
    Policy Loss: 0.0580
    Value Loss: 0.1854
    KL Div: 0.0124
    Entropy: 1.0736
    Avg Reward: -0.0509

  Batch 30 - PPO Epoch 2/4
    Policy Loss: 0.0909
    Value Loss: 0.0774
    KL Div: 0.0098
    Entropy: 1.0791
    Avg Reward: -0.0509

  Batch 30 - PPO Epoch 3/4
    Policy Loss: 0.0921
    Value Loss: 0.0513
    KL Div: 0.0206
    Entropy: 1.0709
    Avg Reward: -0.0509


Rollouts:  66%|██████▌   | 31/47 [12:30<06:27, 24.20s/it]


  Batch 30 - PPO Epoch 4/4
    Policy Loss: 0.0576
    Value Loss: 0.0572
    KL Div: 0.0148
    Entropy: 1.0721
    Avg Reward: -0.0509


Rollouts:  85%|████████▌ | 40/47 [16:07<02:49, 24.17s/it]


  Batch 40 - PPO Epoch 1/4
    Policy Loss: 0.0132
    Value Loss: 0.1203
    KL Div: 0.0143
    Entropy: 0.7951
    Avg Reward: -0.2145

  Batch 40 - PPO Epoch 2/4
    Policy Loss: 0.0112
    Value Loss: 0.0599
    KL Div: 0.0243
    Entropy: 0.8126
    Avg Reward: -0.2145

  Batch 40 - PPO Epoch 3/4
    Policy Loss: 0.0082
    Value Loss: 0.0380
    KL Div: 0.0253
    Entropy: 0.8019
    Avg Reward: -0.2145


Rollouts:  87%|████████▋ | 41/47 [16:31<02:25, 24.17s/it]


  Batch 40 - PPO Epoch 4/4
    Policy Loss: -0.0026
    Value Loss: 0.0229
    KL Div: 0.0146
    Entropy: 0.7843
    Avg Reward: -0.2145


Rollouts: 100%|██████████| 47/47 [18:37<00:00, 23.79s/it]



Epoch 1 Average Reward: -0.1822

-----------------------------
Epoch 2/8


Rollouts:   0%|          | 0/47 [00:00<?, ?it/s]


  Batch 0 - PPO Epoch 1/4
    Policy Loss: 0.0173
    Value Loss: 0.0491
    KL Div: 0.0085
    Entropy: 0.4824
    Avg Reward: -0.2954

  Batch 0 - PPO Epoch 2/4
    Policy Loss: 0.0082
    Value Loss: 0.0602
    KL Div: 0.0065
    Entropy: 0.5063
    Avg Reward: -0.2954

  Batch 0 - PPO Epoch 3/4
    Policy Loss: 0.0076
    Value Loss: 0.0198
    KL Div: 0.0097
    Entropy: 0.5245
    Avg Reward: -0.2954


Rollouts:   2%|▏         | 1/47 [00:24<18:32, 24.18s/it]


  Batch 0 - PPO Epoch 4/4
    Policy Loss: 0.0088
    Value Loss: 0.0202
    KL Div: 0.0290
    Entropy: 0.5470
    Avg Reward: -0.2954


Rollouts:  21%|██▏       | 10/47 [04:01<14:52, 24.12s/it]


  Batch 10 - PPO Epoch 1/4
    Policy Loss: 0.0101
    Value Loss: 0.0347
    KL Div: 0.0053
    Entropy: 0.1906
    Avg Reward: -0.2954

  Batch 10 - PPO Epoch 2/4
    Policy Loss: 0.0151
    Value Loss: 0.0174
    KL Div: 0.0063
    Entropy: 0.2053
    Avg Reward: -0.2954

  Batch 10 - PPO Epoch 3/4
    Policy Loss: 0.0111
    Value Loss: 0.0196
    KL Div: 0.0090
    Entropy: 0.2017
    Avg Reward: -0.2954


Rollouts:  23%|██▎       | 11/47 [04:25<14:28, 24.12s/it]


  Batch 10 - PPO Epoch 4/4
    Policy Loss: 0.0098
    Value Loss: 0.0151
    KL Div: 0.0082
    Entropy: 0.2027
    Avg Reward: -0.2954


Rollouts:  43%|████▎     | 20/47 [08:02<10:53, 24.22s/it]


  Batch 20 - PPO Epoch 1/4
    Policy Loss: 0.0139
    Value Loss: 0.0179
    KL Div: 0.0073
    Entropy: 0.1169
    Avg Reward: -0.2605

  Batch 20 - PPO Epoch 2/4
    Policy Loss: 0.0072
    Value Loss: 0.0058
    KL Div: 0.0021
    Entropy: 0.1184
    Avg Reward: -0.2605

  Batch 20 - PPO Epoch 3/4
    Policy Loss: 0.0000
    Value Loss: 0.0059
    KL Div: 0.0037
    Entropy: 0.1039
    Avg Reward: -0.2605


Rollouts:  45%|████▍     | 21/47 [08:26<10:28, 24.17s/it]


  Batch 20 - PPO Epoch 4/4
    Policy Loss: 0.0040
    Value Loss: 0.0020
    KL Div: 0.0024
    Entropy: 0.1032
    Avg Reward: -0.2605


Rollouts:  64%|██████▍   | 30/47 [12:03<06:49, 24.09s/it]


  Batch 30 - PPO Epoch 1/4
    Policy Loss: 0.0022
    Value Loss: 0.0534
    KL Div: 0.0004
    Entropy: 0.0592
    Avg Reward: -0.1731

  Batch 30 - PPO Epoch 2/4
    Policy Loss: 0.0036
    Value Loss: 0.0475
    KL Div: 0.0010
    Entropy: 0.0600
    Avg Reward: -0.1731

  Batch 30 - PPO Epoch 3/4
    Policy Loss: 0.0024
    Value Loss: 0.0476
    KL Div: 0.0007
    Entropy: 0.0602
    Avg Reward: -0.1731


Rollouts:  66%|██████▌   | 31/47 [12:27<06:25, 24.12s/it]


  Batch 30 - PPO Epoch 4/4
    Policy Loss: 0.0020
    Value Loss: 0.0505
    KL Div: 0.0020
    Entropy: 0.0603
    Avg Reward: -0.1731


Rollouts:  85%|████████▌ | 40/47 [16:04<02:48, 24.08s/it]


  Batch 40 - PPO Epoch 1/4
    Policy Loss: 0.0067
    Value Loss: 0.0009
    KL Div: 0.0007
    Entropy: 0.0625
    Avg Reward: -0.2537

  Batch 40 - PPO Epoch 2/4
    Policy Loss: 0.0082
    Value Loss: 0.0010
    KL Div: 0.0007
    Entropy: 0.0604
    Avg Reward: -0.2537

  Batch 40 - PPO Epoch 3/4
    Policy Loss: 0.0038
    Value Loss: 0.0020
    KL Div: 0.0005
    Entropy: 0.0595
    Avg Reward: -0.2537


Rollouts:  87%|████████▋ | 41/47 [16:28<02:24, 24.10s/it]


  Batch 40 - PPO Epoch 4/4
    Policy Loss: 0.0031
    Value Loss: 0.0015
    KL Div: 0.0006
    Entropy: 0.0664
    Avg Reward: -0.2537


Rollouts: 100%|██████████| 47/47 [18:34<00:00, 23.72s/it]



Epoch 2 Average Reward: -0.1908

-----------------------------
Epoch 3/8


Rollouts:   0%|          | 0/47 [00:00<?, ?it/s]


  Batch 0 - PPO Epoch 1/4
    Policy Loss: 0.0068
    Value Loss: 0.0485
    KL Div: 0.0088
    Entropy: 0.0828
    Avg Reward: -0.1857

  Batch 0 - PPO Epoch 2/4
    Policy Loss: -0.0005
    Value Loss: 0.0479
    KL Div: 0.0598
    Entropy: 0.1934
    Avg Reward: -0.1857

  Batch 0 - PPO Epoch 3/4
    Policy Loss: -0.0007
    Value Loss: 0.0638
    KL Div: 0.1004
    Entropy: 0.2449
    Avg Reward: -0.1857


Rollouts:   2%|▏         | 1/47 [00:24<18:32, 24.18s/it]


  Batch 0 - PPO Epoch 4/4
    Policy Loss: 0.0038
    Value Loss: 0.0640
    KL Div: 0.0118
    Entropy: 0.1129
    Avg Reward: -0.1857


Rollouts:  21%|██▏       | 10/47 [04:01<14:51, 24.10s/it]


  Batch 10 - PPO Epoch 1/4
    Policy Loss: 0.0002
    Value Loss: 0.1237
    KL Div: 0.0009
    Entropy: 0.1514
    Avg Reward: -0.1110

  Batch 10 - PPO Epoch 2/4
    Policy Loss: 0.0032
    Value Loss: 0.0962
    KL Div: -0.0107
    Entropy: 0.0975
    Avg Reward: -0.1110

  Batch 10 - PPO Epoch 3/4
    Policy Loss: 0.0008
    Value Loss: 0.1132
    KL Div: -0.0120
    Entropy: 0.0788
    Avg Reward: -0.1110


Rollouts:  23%|██▎       | 11/47 [04:25<14:28, 24.12s/it]


  Batch 10 - PPO Epoch 4/4
    Policy Loss: 0.0015
    Value Loss: 0.0986
    KL Div: -0.0045
    Entropy: 0.0949
    Avg Reward: -0.1110


Rollouts:  43%|████▎     | 20/47 [08:02<10:51, 24.13s/it]


  Batch 20 - PPO Epoch 1/4
    Policy Loss: 0.0022
    Value Loss: 0.1063
    KL Div: 0.0010
    Entropy: 0.0697
    Avg Reward: -0.1002

  Batch 20 - PPO Epoch 2/4
    Policy Loss: 0.0031
    Value Loss: 0.0941
    KL Div: 0.0023
    Entropy: 0.0761
    Avg Reward: -0.1002

  Batch 20 - PPO Epoch 3/4
    Policy Loss: 0.0016
    Value Loss: 0.0951
    KL Div: 0.0040
    Entropy: 0.0848
    Avg Reward: -0.1002


Rollouts:  45%|████▍     | 21/47 [08:26<10:27, 24.13s/it]


  Batch 20 - PPO Epoch 4/4
    Policy Loss: 0.0009
    Value Loss: 0.0977
    KL Div: 0.0073
    Entropy: 0.0892
    Avg Reward: -0.1002


Rollouts:  64%|██████▍   | 30/47 [12:03<06:49, 24.06s/it]


  Batch 30 - PPO Epoch 1/4
    Policy Loss: 0.0121
    Value Loss: 0.0474
    KL Div: 0.0001
    Entropy: 0.1262
    Avg Reward: -0.2620

  Batch 30 - PPO Epoch 2/4
    Policy Loss: 0.0091
    Value Loss: 0.0101
    KL Div: -0.0025
    Entropy: 0.1257
    Avg Reward: -0.2620

  Batch 30 - PPO Epoch 3/4
    Policy Loss: -0.0004
    Value Loss: 0.0274
    KL Div: 0.0076
    Entropy: 0.1011
    Avg Reward: -0.2620


Rollouts:  66%|██████▌   | 31/47 [12:27<06:24, 24.06s/it]


  Batch 30 - PPO Epoch 4/4
    Policy Loss: -0.0078
    Value Loss: 0.0103
    KL Div: 0.0164
    Entropy: 0.1099
    Avg Reward: -0.2620


Rollouts:  85%|████████▌ | 40/47 [16:04<02:48, 24.07s/it]


  Batch 40 - PPO Epoch 1/4
    Policy Loss: 0.0000
    Value Loss: 0.0444
    KL Div: 0.0003
    Entropy: 0.0658
    Avg Reward: -0.1731

  Batch 40 - PPO Epoch 2/4
    Policy Loss: 0.0003
    Value Loss: 0.0416
    KL Div: -0.0001
    Entropy: 0.0637
    Avg Reward: -0.1731

  Batch 40 - PPO Epoch 3/4
    Policy Loss: -0.0010
    Value Loss: 0.0453
    KL Div: -0.0005
    Entropy: 0.0593
    Avg Reward: -0.1731


Rollouts:  87%|████████▋ | 41/47 [16:28<02:24, 24.05s/it]


  Batch 40 - PPO Epoch 4/4
    Policy Loss: -0.0035
    Value Loss: 0.0449
    KL Div: -0.0010
    Entropy: 0.0578
    Avg Reward: -0.1731


Rollouts: 100%|██████████| 47/47 [18:34<00:00, 23.71s/it]



Epoch 3 Average Reward: -0.1832

-----------------------------
Epoch 4/8


Rollouts:   0%|          | 0/47 [00:00<?, ?it/s]


  Batch 0 - PPO Epoch 1/4
    Policy Loss: 0.0026
    Value Loss: 0.0039
    KL Div: 0.0011
    Entropy: 0.0994
    Avg Reward: -0.2605

  Batch 0 - PPO Epoch 2/4
    Policy Loss: -0.0011
    Value Loss: 0.0040
    KL Div: 0.0077
    Entropy: 0.1046
    Avg Reward: -0.2605

  Batch 0 - PPO Epoch 3/4
    Policy Loss: -0.0016
    Value Loss: 0.0047
    KL Div: 0.0058
    Entropy: 0.1032
    Avg Reward: -0.2605


Rollouts:   2%|▏         | 1/47 [00:24<18:30, 24.14s/it]


  Batch 0 - PPO Epoch 4/4
    Policy Loss: -0.0044
    Value Loss: 0.0030
    KL Div: 0.0013
    Entropy: 0.1063
    Avg Reward: -0.2605


Rollouts:  21%|██▏       | 10/47 [04:01<14:51, 24.10s/it]


  Batch 10 - PPO Epoch 1/4
    Policy Loss: 0.0112
    Value Loss: 0.0023
    KL Div: 0.0001
    Entropy: 0.0623
    Avg Reward: -0.2537

  Batch 10 - PPO Epoch 2/4
    Policy Loss: 0.0058
    Value Loss: 0.0025
    KL Div: -0.0003
    Entropy: 0.0586
    Avg Reward: -0.2537

  Batch 10 - PPO Epoch 3/4
    Policy Loss: 0.0047
    Value Loss: 0.0014
    KL Div: 0.0009
    Entropy: 0.0577
    Avg Reward: -0.2537


Rollouts:  23%|██▎       | 11/47 [04:25<14:27, 24.09s/it]


  Batch 10 - PPO Epoch 4/4
    Policy Loss: 0.0041
    Value Loss: 0.0004
    KL Div: 0.0010
    Entropy: 0.0576
    Avg Reward: -0.2537


Rollouts:  43%|████▎     | 20/47 [08:02<10:50, 24.09s/it]


  Batch 20 - PPO Epoch 1/4
    Policy Loss: 0.0021
    Value Loss: 0.0472
    KL Div: 0.0000
    Entropy: 0.0508
    Avg Reward: -0.1748

  Batch 20 - PPO Epoch 2/4
    Policy Loss: 0.0038
    Value Loss: 0.0510
    KL Div: 0.0000
    Entropy: 0.0506
    Avg Reward: -0.1748

  Batch 20 - PPO Epoch 3/4
    Policy Loss: -0.0004
    Value Loss: 0.0490
    KL Div: 0.0001
    Entropy: 0.0506
    Avg Reward: -0.1748


Rollouts:  45%|████▍     | 21/47 [08:26<10:26, 24.08s/it]


  Batch 20 - PPO Epoch 4/4
    Policy Loss: 0.0019
    Value Loss: 0.0477
    KL Div: 0.0001
    Entropy: 0.0508
    Avg Reward: -0.1748


Rollouts:  64%|██████▍   | 30/47 [12:02<06:49, 24.07s/it]


  Batch 30 - PPO Epoch 1/4
    Policy Loss: 0.0000
    Value Loss: 0.0805
    KL Div: 0.0001
    Entropy: 0.0503
    Avg Reward: -0.0906

  Batch 30 - PPO Epoch 2/4
    Policy Loss: -0.0010
    Value Loss: 0.0905
    KL Div: 0.0014
    Entropy: 0.0502
    Avg Reward: -0.0906

  Batch 30 - PPO Epoch 3/4
    Policy Loss: -0.0012
    Value Loss: 0.0640
    KL Div: -0.0000
    Entropy: 0.0501
    Avg Reward: -0.0906


Rollouts:  66%|██████▌   | 31/47 [12:26<06:24, 24.05s/it]


  Batch 30 - PPO Epoch 4/4
    Policy Loss: 0.0028
    Value Loss: 0.0826
    KL Div: -0.0000
    Entropy: 0.0504
    Avg Reward: -0.0906


Rollouts:  85%|████████▌ | 40/47 [16:03<02:48, 24.08s/it]


  Batch 40 - PPO Epoch 1/4
    Policy Loss: 0.0032
    Value Loss: 0.0516
    KL Div: 0.0001
    Entropy: 0.0514
    Avg Reward: -0.1711

  Batch 40 - PPO Epoch 2/4
    Policy Loss: 0.0012
    Value Loss: 0.0474
    KL Div: -0.0000
    Entropy: 0.0507
    Avg Reward: -0.1711

  Batch 40 - PPO Epoch 3/4
    Policy Loss: 0.0017
    Value Loss: 0.0480
    KL Div: -0.0000
    Entropy: 0.0510
    Avg Reward: -0.1711


Rollouts:  87%|████████▋ | 41/47 [16:27<02:24, 24.04s/it]


  Batch 40 - PPO Epoch 4/4
    Policy Loss: -0.0003
    Value Loss: 0.0481
    KL Div: 0.0005
    Entropy: 0.0508
    Avg Reward: -0.1711


Rollouts: 100%|██████████| 47/47 [18:33<00:00, 23.68s/it]



Epoch 4 Average Reward: -0.1753

-----------------------------
Epoch 5/8


Rollouts:   0%|          | 0/47 [00:00<?, ?it/s]


  Batch 0 - PPO Epoch 1/4
    Policy Loss: 0.0011
    Value Loss: 0.1063
    KL Div: -0.0001
    Entropy: 0.0529
    Avg Reward: -0.0922

  Batch 0 - PPO Epoch 2/4
    Policy Loss: 1.8145
    Value Loss: 0.1004
    KL Div: -0.0003
    Entropy: 0.0517
    Avg Reward: -0.0922

  Batch 0 - PPO Epoch 3/4
    Policy Loss: 0.0147
    Value Loss: 0.0981
    KL Div: -0.0002
    Entropy: 0.0517
    Avg Reward: -0.0922


Rollouts:   2%|▏         | 1/47 [00:24<18:27, 24.08s/it]


  Batch 0 - PPO Epoch 4/4
    Policy Loss: -0.0013
    Value Loss: 0.0913
    KL Div: -0.0002
    Entropy: 0.0516
    Avg Reward: -0.0922


Rollouts:  21%|██▏       | 10/47 [04:01<14:51, 24.11s/it]


  Batch 10 - PPO Epoch 1/4
    Policy Loss: 0.0106
    Value Loss: 0.0015
    KL Div: 0.0001
    Entropy: 0.0489
    Avg Reward: -0.2551

  Batch 10 - PPO Epoch 2/4
    Policy Loss: 0.0017
    Value Loss: 0.0017
    KL Div: 0.0000
    Entropy: 0.0487
    Avg Reward: -0.2551

  Batch 10 - PPO Epoch 3/4
    Policy Loss: 0.0013
    Value Loss: 0.0022
    KL Div: 0.0000
    Entropy: 0.0487
    Avg Reward: -0.2551


Rollouts:  23%|██▎       | 11/47 [04:25<14:27, 24.09s/it]


  Batch 10 - PPO Epoch 4/4
    Policy Loss: 0.0032
    Value Loss: 0.0028
    KL Div: 0.0000
    Entropy: 0.0489
    Avg Reward: -0.2551


Rollouts:  43%|████▎     | 20/47 [08:02<10:50, 24.08s/it]


  Batch 20 - PPO Epoch 1/4
    Policy Loss: 0.0046
    Value Loss: 0.0495
    KL Div: 0.0000
    Entropy: 0.0491
    Avg Reward: -0.1747

  Batch 20 - PPO Epoch 2/4
    Policy Loss: 12.9168
    Value Loss: 0.0447
    KL Div: 0.0000
    Entropy: 0.0491
    Avg Reward: -0.1747

  Batch 20 - PPO Epoch 3/4
    Policy Loss: 0.0014
    Value Loss: 0.0470
    KL Div: 0.0000
    Entropy: 0.0489
    Avg Reward: -0.1747


Rollouts:  45%|████▍     | 21/47 [08:26<10:26, 24.10s/it]


  Batch 20 - PPO Epoch 4/4
    Policy Loss: 0.0022
    Value Loss: 0.0430
    KL Div: 0.0000
    Entropy: 0.0489
    Avg Reward: -0.1747


Rollouts:  64%|██████▍   | 30/47 [12:03<06:49, 24.08s/it]


  Batch 30 - PPO Epoch 1/4
    Policy Loss: 0.0185
    Value Loss: 0.0176
    KL Div: 0.0002
    Entropy: 0.0482
    Avg Reward: -0.2532

  Batch 30 - PPO Epoch 2/4
    Policy Loss: 0.0109
    Value Loss: 0.0199
    KL Div: 0.0001
    Entropy: 0.0481
    Avg Reward: -0.2532

  Batch 30 - PPO Epoch 3/4
    Policy Loss: 0.0060
    Value Loss: 0.0138
    KL Div: 0.0001
    Entropy: 0.0482
    Avg Reward: -0.2532


Rollouts:  66%|██████▌   | 31/47 [12:27<06:25, 24.09s/it]


  Batch 30 - PPO Epoch 4/4
    Policy Loss: 0.0052
    Value Loss: 0.0085
    KL Div: 0.0000
    Entropy: 0.0483
    Avg Reward: -0.2532


Rollouts:  85%|████████▌ | 40/47 [16:03<02:48, 24.08s/it]


  Batch 40 - PPO Epoch 1/4
    Policy Loss: 0.0089
    Value Loss: 0.0030
    KL Div: 0.0007
    Entropy: 0.0606
    Avg Reward: -0.2532

  Batch 40 - PPO Epoch 2/4
    Policy Loss: 0.0056
    Value Loss: 0.0070
    KL Div: 0.0020
    Entropy: 0.0730
    Avg Reward: -0.2532

  Batch 40 - PPO Epoch 3/4
    Policy Loss: 0.0028
    Value Loss: 0.0050
    KL Div: 0.0026
    Entropy: 0.0601
    Avg Reward: -0.2532


Rollouts:  87%|████████▋ | 41/47 [16:28<02:24, 24.08s/it]


  Batch 40 - PPO Epoch 4/4
    Policy Loss: -0.0002
    Value Loss: 0.0039
    KL Div: 0.0004
    Entropy: 0.0614
    Avg Reward: -0.2532


Rollouts: 100%|██████████| 47/47 [18:33<00:00, 23.70s/it]



Epoch 5 Average Reward: -0.1753

-----------------------------
Epoch 6/8


Rollouts:   0%|          | 0/47 [00:00<?, ?it/s]


  Batch 0 - PPO Epoch 1/4
    Policy Loss: 0.0002
    Value Loss: 0.1104
    KL Div: -0.0001
    Entropy: 0.0518
    Avg Reward: -0.1736

  Batch 0 - PPO Epoch 2/4
    Policy Loss: -0.0006
    Value Loss: 0.0408
    KL Div: -0.0001
    Entropy: 0.0515
    Avg Reward: -0.1736

  Batch 0 - PPO Epoch 3/4
    Policy Loss: 0.0002
    Value Loss: 0.0561
    KL Div: -0.0002
    Entropy: 0.0507
    Avg Reward: -0.1736


Rollouts:   2%|▏         | 1/47 [00:24<18:28, 24.09s/it]


  Batch 0 - PPO Epoch 4/4
    Policy Loss: -0.0011
    Value Loss: 0.0410
    KL Div: -0.0002
    Entropy: 0.0503
    Avg Reward: -0.1736


Rollouts:  21%|██▏       | 10/47 [04:00<14:50, 24.06s/it]


  Batch 10 - PPO Epoch 1/4
    Policy Loss: 0.0056
    Value Loss: 0.0023
    KL Div: 0.0001
    Entropy: 0.0488
    Avg Reward: -0.2507

  Batch 10 - PPO Epoch 2/4
    Policy Loss: 0.0071
    Value Loss: 0.0020
    KL Div: 0.0001
    Entropy: 0.0487
    Avg Reward: -0.2507

  Batch 10 - PPO Epoch 3/4
    Policy Loss: 0.0070
    Value Loss: 0.0047
    KL Div: 0.0001
    Entropy: 0.0492
    Avg Reward: -0.2507


Rollouts:  23%|██▎       | 11/47 [04:24<14:26, 24.07s/it]


  Batch 10 - PPO Epoch 4/4
    Policy Loss: 0.0034
    Value Loss: 0.0042
    KL Div: 0.0002
    Entropy: 0.0499
    Avg Reward: -0.2507


Rollouts:  43%|████▎     | 20/47 [08:01<10:50, 24.08s/it]


  Batch 20 - PPO Epoch 1/4
    Policy Loss: 0.0024
    Value Loss: 0.0057
    KL Div: 0.0001
    Entropy: 0.0506
    Avg Reward: -0.2544

  Batch 20 - PPO Epoch 2/4
    Policy Loss: 0.0004
    Value Loss: 0.0035
    KL Div: 0.0003
    Entropy: 0.0526
    Avg Reward: -0.2544

  Batch 20 - PPO Epoch 3/4
    Policy Loss: -0.0011
    Value Loss: 0.0040
    KL Div: 0.0019
    Entropy: 0.0599
    Avg Reward: -0.2544


Rollouts:  45%|████▍     | 21/47 [08:25<10:26, 24.08s/it]


  Batch 20 - PPO Epoch 4/4
    Policy Loss: -0.0038
    Value Loss: 0.0018
    KL Div: 0.0018
    Entropy: 0.0567
    Avg Reward: -0.2544


Rollouts:  64%|██████▍   | 30/47 [12:02<06:48, 24.05s/it]


  Batch 30 - PPO Epoch 1/4
    Policy Loss: -0.0009
    Value Loss: 0.0743
    KL Div: 0.0001
    Entropy: 0.0496
    Avg Reward: -0.0898

  Batch 30 - PPO Epoch 2/4
    Policy Loss: -0.0007
    Value Loss: 0.0695
    KL Div: 0.0004
    Entropy: 0.0500
    Avg Reward: -0.0898

  Batch 30 - PPO Epoch 3/4
    Policy Loss: -0.0005
    Value Loss: 0.0665
    KL Div: 0.0004
    Entropy: 0.0501
    Avg Reward: -0.0898


Rollouts:  66%|██████▌   | 31/47 [12:26<06:24, 24.06s/it]


  Batch 30 - PPO Epoch 4/4
    Policy Loss: -0.0027
    Value Loss: 0.0580
    KL Div: 0.0008
    Entropy: 0.0510
    Avg Reward: -0.0898


Rollouts:  85%|████████▌ | 40/47 [16:02<02:48, 24.04s/it]


  Batch 40 - PPO Epoch 1/4
    Policy Loss: 0.0059
    Value Loss: 0.0033
    KL Div: 0.0000
    Entropy: 0.0471
    Avg Reward: -0.2532

  Batch 40 - PPO Epoch 2/4
    Policy Loss: 0.0093
    Value Loss: 0.0023
    KL Div: 0.0000
    Entropy: 0.0471
    Avg Reward: -0.2532

  Batch 40 - PPO Epoch 3/4
    Policy Loss: 0.0052
    Value Loss: 0.0031
    KL Div: 0.0000
    Entropy: 0.0471
    Avg Reward: -0.2532


Rollouts:  87%|████████▋ | 41/47 [16:26<02:24, 24.02s/it]


  Batch 40 - PPO Epoch 4/4
    Policy Loss: 0.0032
    Value Loss: 0.0026
    KL Div: 0.0000
    Entropy: 0.0471
    Avg Reward: -0.2532


Rollouts: 100%|██████████| 47/47 [18:32<00:00, 23.67s/it]



Epoch 6 Average Reward: -0.1753

-----------------------------
Epoch 7/8


Rollouts:   0%|          | 0/47 [00:00<?, ?it/s]


  Batch 0 - PPO Epoch 1/4
    Policy Loss: 0.0050
    Value Loss: 0.0856
    KL Div: 0.0000
    Entropy: 0.0472
    Avg Reward: -0.0911

  Batch 0 - PPO Epoch 2/4
    Policy Loss: 0.0032
    Value Loss: 0.0864
    KL Div: 0.0000
    Entropy: 0.0470
    Avg Reward: -0.0911

  Batch 0 - PPO Epoch 3/4
    Policy Loss: 0.0021
    Value Loss: 0.0927
    KL Div: 0.0000
    Entropy: 0.0471
    Avg Reward: -0.0911


Rollouts:   2%|▏         | 1/47 [00:24<18:30, 24.14s/it]


  Batch 0 - PPO Epoch 4/4
    Policy Loss: 0.0025
    Value Loss: 0.0800
    KL Div: 0.0000
    Entropy: 0.0471
    Avg Reward: -0.0911


Rollouts:  21%|██▏       | 10/47 [04:01<14:54, 24.18s/it]


  Batch 10 - PPO Epoch 1/4
    Policy Loss: 0.0021
    Value Loss: 0.0470
    KL Div: 0.0000
    Entropy: 0.0468
    Avg Reward: -0.1731

  Batch 10 - PPO Epoch 2/4
    Policy Loss: 0.0018
    Value Loss: 0.0550
    KL Div: 0.0000
    Entropy: 0.0469
    Avg Reward: -0.1731

  Batch 10 - PPO Epoch 3/4
    Policy Loss: -0.0005
    Value Loss: 0.0487
    KL Div: 0.0000
    Entropy: 0.0468
    Avg Reward: -0.1731


Rollouts:  23%|██▎       | 11/47 [04:25<14:29, 24.14s/it]


  Batch 10 - PPO Epoch 4/4
    Policy Loss: 0.0003
    Value Loss: 0.0478
    KL Div: 0.0000
    Entropy: 0.0468
    Avg Reward: -0.1731


Rollouts:  43%|████▎     | 20/47 [08:02<10:51, 24.13s/it]


  Batch 20 - PPO Epoch 1/4
    Policy Loss: -0.0001
    Value Loss: 0.0340
    KL Div: 0.0005
    Entropy: 0.0571
    Avg Reward: -0.2520

  Batch 20 - PPO Epoch 2/4
    Policy Loss: -0.0014
    Value Loss: 0.0161
    KL Div: 0.0016
    Entropy: 0.0649
    Avg Reward: -0.2520

  Batch 20 - PPO Epoch 3/4
    Policy Loss: -0.0005
    Value Loss: 0.0128
    KL Div: 0.0495
    Entropy: 0.1489
    Avg Reward: -0.2520


Rollouts:  45%|████▍     | 21/47 [08:26<10:27, 24.13s/it]


  Batch 20 - PPO Epoch 4/4
    Policy Loss: 0.0200
    Value Loss: 0.0054
    KL Div: 0.1227
    Entropy: 0.1862
    Avg Reward: -0.2520


Rollouts:  64%|██████▍   | 30/47 [12:03<06:49, 24.10s/it]


  Batch 30 - PPO Epoch 1/4
    Policy Loss: 0.0035
    Value Loss: 0.0438
    KL Div: 0.0014
    Entropy: 0.0753
    Avg Reward: -0.1741

  Batch 30 - PPO Epoch 2/4
    Policy Loss: 0.0042
    Value Loss: 0.0325
    KL Div: 0.0027
    Entropy: 0.0799
    Avg Reward: -0.1741

  Batch 30 - PPO Epoch 3/4
    Policy Loss: 0.0048
    Value Loss: 0.0258
    KL Div: 0.0018
    Entropy: 0.0749
    Avg Reward: -0.1741


Rollouts:  66%|██████▌   | 31/47 [12:27<06:25, 24.09s/it]


  Batch 30 - PPO Epoch 4/4
    Policy Loss: 0.0037
    Value Loss: 0.0209
    KL Div: 0.0015
    Entropy: 0.0730
    Avg Reward: -0.1741


Rollouts:  85%|████████▌ | 40/47 [16:05<02:49, 24.21s/it]


  Batch 40 - PPO Epoch 1/4
    Policy Loss: 0.0029
    Value Loss: 0.0089
    KL Div: 0.0000
    Entropy: 0.0515
    Avg Reward: -0.2512

  Batch 40 - PPO Epoch 2/4
    Policy Loss: -0.0010
    Value Loss: 0.0030
    KL Div: 0.0000
    Entropy: 0.0517
    Avg Reward: -0.2512

  Batch 40 - PPO Epoch 3/4
    Policy Loss: -0.0041
    Value Loss: 0.0034
    KL Div: 0.0001
    Entropy: 0.0516
    Avg Reward: -0.2512


Rollouts:  87%|████████▋ | 41/47 [16:29<02:25, 24.18s/it]


  Batch 40 - PPO Epoch 4/4
    Policy Loss: -0.0012
    Value Loss: 0.0014
    KL Div: 0.0001
    Entropy: 0.0520
    Avg Reward: -0.2512


Rollouts: 100%|██████████| 47/47 [18:34<00:00, 23.72s/it]



Epoch 7 Average Reward: -0.1747

-----------------------------
Epoch 8/8


Rollouts:   0%|          | 0/47 [00:00<?, ?it/s]


  Batch 0 - PPO Epoch 1/4
    Policy Loss: 0.0031
    Value Loss: 0.0905
    KL Div: 0.0000
    Entropy: 0.0510
    Avg Reward: -0.0908

  Batch 0 - PPO Epoch 2/4
    Policy Loss: 0.0104
    Value Loss: 0.0803
    KL Div: -0.0000
    Entropy: 0.0510
    Avg Reward: -0.0908

  Batch 0 - PPO Epoch 3/4
    Policy Loss: 0.0033
    Value Loss: 0.0790
    KL Div: 0.0000
    Entropy: 0.0511
    Avg Reward: -0.0908


Rollouts:   2%|▏         | 1/47 [00:24<18:32, 24.18s/it]


  Batch 0 - PPO Epoch 4/4
    Policy Loss: 0.0005
    Value Loss: 0.0562
    KL Div: 0.0001
    Entropy: 0.0513
    Avg Reward: -0.0908


Rollouts:  21%|██▏       | 10/47 [04:00<14:51, 24.08s/it]


  Batch 10 - PPO Epoch 1/4
    Policy Loss: 0.0024
    Value Loss: 0.0094
    KL Div: 0.0001
    Entropy: 0.0490
    Avg Reward: -0.2539

  Batch 10 - PPO Epoch 2/4
    Policy Loss: 0.0247
    Value Loss: 0.0107
    KL Div: 0.0000
    Entropy: 0.0490
    Avg Reward: -0.2539

  Batch 10 - PPO Epoch 3/4
    Policy Loss: 0.0002
    Value Loss: 0.0035
    KL Div: 0.0000
    Entropy: 0.0496
    Avg Reward: -0.2539


Rollouts:  23%|██▎       | 11/47 [04:24<14:28, 24.11s/it]


  Batch 10 - PPO Epoch 4/4
    Policy Loss: -0.0003
    Value Loss: 0.0016
    KL Div: 0.0001
    Entropy: 0.0498
    Avg Reward: -0.2539


Rollouts:  43%|████▎     | 20/47 [08:01<10:48, 24.02s/it]


  Batch 20 - PPO Epoch 1/4
    Policy Loss: 0.0070
    Value Loss: 0.0048
    KL Div: 0.0000
    Entropy: 0.0480
    Avg Reward: -0.2581

  Batch 20 - PPO Epoch 2/4
    Policy Loss: 0.0022
    Value Loss: 0.0040
    KL Div: 0.0000
    Entropy: 0.0479
    Avg Reward: -0.2581

  Batch 20 - PPO Epoch 3/4
    Policy Loss: 0.0005
    Value Loss: 0.0018
    KL Div: 0.0002
    Entropy: 0.0483
    Avg Reward: -0.2581


Rollouts:  45%|████▍     | 21/47 [08:25<10:24, 24.03s/it]


  Batch 20 - PPO Epoch 4/4
    Policy Loss: 0.0001
    Value Loss: 0.0006
    KL Div: 0.0006
    Entropy: 0.0492
    Avg Reward: -0.2581


Rollouts:  64%|██████▍   | 30/47 [12:02<06:48, 24.05s/it]


  Batch 30 - PPO Epoch 1/4
    Policy Loss: 0.0014
    Value Loss: 0.1733
    KL Div: 0.0000
    Entropy: 0.0480
    Avg Reward: -0.0086

  Batch 30 - PPO Epoch 2/4
    Policy Loss: -0.0026
    Value Loss: 0.1350
    KL Div: 0.0003
    Entropy: 0.0486
    Avg Reward: -0.0086

  Batch 30 - PPO Epoch 3/4
    Policy Loss: -0.0044
    Value Loss: 0.1167
    KL Div: 0.0003
    Entropy: 0.0491
    Avg Reward: -0.0086


Rollouts:  66%|██████▌   | 31/47 [12:26<06:24, 24.04s/it]


  Batch 30 - PPO Epoch 4/4
    Policy Loss: -0.0018
    Value Loss: 0.1224
    KL Div: 0.0003
    Entropy: 0.0489
    Avg Reward: -0.0086


Rollouts:  85%|████████▌ | 40/47 [16:02<02:48, 24.01s/it]


  Batch 40 - PPO Epoch 1/4
    Policy Loss: -0.0007
    Value Loss: 0.0335
    KL Div: 0.0000
    Entropy: 0.0474
    Avg Reward: -0.1735

  Batch 40 - PPO Epoch 2/4
    Policy Loss: -0.0019
    Value Loss: 0.0354
    KL Div: 0.0000
    Entropy: 0.0475
    Avg Reward: -0.1735

  Batch 40 - PPO Epoch 3/4
    Policy Loss: -0.0015
    Value Loss: 0.0280
    KL Div: 0.0000
    Entropy: 0.0475
    Avg Reward: -0.1735


Rollouts:  87%|████████▋ | 41/47 [16:26<02:24, 24.02s/it]


  Batch 40 - PPO Epoch 4/4
    Policy Loss: -0.0031
    Value Loss: 0.0245
    KL Div: 0.0000
    Entropy: 0.0475
    Avg Reward: -0.1735


Rollouts: 100%|██████████| 47/47 [18:31<00:00, 23.65s/it]


Epoch 8 Average Reward: -0.1628






In [34]:
trained_policy.save_pretrained('/kaggle/working/best_policy/')

In [36]:
!zip -r '/kaggle/working/best_policy' '/kaggle/working/'

  adding: kaggle/working/ (stored 0%)
  adding: kaggle/working/best_policy/ (stored 0%)
  adding: kaggle/working/best_policy/README.md (deflated 65%)
  adding: kaggle/working/best_policy/adapter_config.json (deflated 55%)
  adding: kaggle/working/best_policy/adapter_model.safetensors

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


 (deflated 7%)
  adding: kaggle/working/.virtual_documents/ (stored 0%)


In [37]:
test_prompt = "What is 2 + 1?"
inputs = tokenizer(test_prompt, return_tensors="pt").to('cuda')
output = trained_policy.generate(
    **inputs, 
    max_new_tokens=3,
    do_sample=False  # Greedy for testing
)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


What is 2 + 1?
-2


In [None]:
import optuna

In [None]:
def tune_ppo_trainer(
    policy_model: PeftModel, 
    value_model: PeftModel,
    ref_model: PeftModel,
    tokenizer: PreTrainedTokenizerBase,
    dataset: Dataset,
    policy_optimizer: torch.optim.AdamW,
    value_optimizer: torch.optim.AdamW,
    reward_model: Optional[Callable[[list[str], list[str]], list[float]]] = None,
    num_epochs=3,
    rollout_batch_size=32,
    mini_batch_size=8,
    ppo_epochs=4,
    clip_eps: float=0.1,
    kl_coef: float=0.01,
    value_coef: float=0.5,
    entropy_coef: float=0.01,
    gamma: float=0.99,
    lam: float=0.95,
    max_gen_len: int=128
):
    def collate_fn(batch):
        input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch])
        attention_mask = torch.stack([torch.tensor(item['attention_mask']) for item in batch])
        answers = [item['answer'] for item in batch]
        questions = [item['question'] for item in batch]
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'answers': answers,
            'questions': questions
        }
        
    dataloader = DataLoader(dataset, batch_size=rollout_batch_size, collate_fn=collate_fn,shuffle=True)
    all_epoch_rewards = []

    for epoch in range(num_epochs):
        print('-----------------------------')
        print(f"Epoch {epoch+1}/{num_epochs}")
        for b_idx, batch in enumerate(tqdm(dataloader, desc='Rollouts')):
            policy_model.eval()
            value_model.eval()
            ref_model.eval()
            
            rollout_buffer=[]
            epoch_rewards = []

            with torch.no_grad():
                input_ids = batch['input_ids'].to(policy_model.device)
                attention_mask = batch['attention_mask'].to(policy_model.device)
                answers = batch['answers']
                
                response_token = policy_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_gen_len,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=tokenizer.pad_token_id
                )
                responses=tokenizer.batch_decode(response_token, skip_special_tokens=True)
                
                rewards = compute_reward(responses, answers)
                rewards = rewards.to(policy_model.device)
                batch_avg_reward = rewards.mean().item()
                epoch_rewards.append(batch_avg_reward)

                attention_mask = (response_token != tokenizer.pad_token_id).long()

                values = value_model(
                    input_ids=response_token,
                    attention_mask=attention_mask
                ).logits.squeeze(-1)

                old_log_probs, mask = get_log_probs(
                    policy_model, response_token, attention_mask
                )
                ref_log_probs,_ = get_log_probs(
                    ref_model, response_token, attention_mask
                )
                advantages, returns = compute_gae(rewards, values, gamma, lam)
                advantages = (advantages-advantages.mean()) / (advantages.std()+1e-8)

                for i in range(len(batch)):
                    rollout_buffer.append({
                        'input_ids': response_token[i].cpu(),
                        'attention_mask': attention_mask[i].cpu(),
                        'old_log_probs': old_log_probs[i].cpu(),
                        'ref_log_probs': ref_log_probs[i].cpu(),
                        'advantages': advantages[i].cpu(),
                        'returns': returns[i].cpu(),
                        'mask': mask[i].cpu()
                    })
            policy_model.train()
            value_model.train()

            for ppo_epoch in range(ppo_epochs):
                import random
                random.shuffle(rollout_buffer)
                num_minibatches = len(rollout_buffer)//mini_batch_size

                ppo_policy_losses =[]
                ppo_value_losses =[]
                ppo_kl_penalties = []
                ppo_entropies = []
                for mb_idx in range(num_minibatches):
                    start_idx = mb_idx* mini_batch_size
                    end_idx = start_idx + mini_batch_size 
                    minibatch = rollout_buffer[start_idx:end_idx]
                    
                    #stack minibatches
                    mb_input_ids = torch.stack([item['input_ids']for item in minibatch])
                    mb_attention_mask = torch.stack([item['attention_mask'] for item in minibatch])
                    mb_old_log_probs = torch.stack([item['old_log_probs'] for item in minibatch])
                    mb_ref_log_probs = torch.stack([item['ref_log_probs'] for item in minibatch])
                    mb_advantages = torch.stack([item['advantages'] for item in minibatch])
                    mb_returns = torch.stack([item['returns'] for item in minibatch])
                    mb_mask = torch.stack([item['mask'] for item in minibatch])

                    #move to device
                    mb_input_ids = mb_input_ids.to(policy_model.device)
                    mb_attention_mask = mb_attention_mask.to(policy_model.device)
                    mb_advantages = mb_advantages.to(policy_model.device)
                    mb_old_log_probs = mb_old_log_probs.to(policy_model.device)
                    mb_ref_log_probs = mb_ref_log_probs.to(policy_model.device)
                    mb_returns = mb_returns.to(policy_model.device)
                    mb_mask = mb_mask.to(policy_model.device)
                    
                    policy_optimizer.zero_grad()
                    new_log_probs, _ = get_log_probs(
                        policy_model,
                        mb_input_ids,
                        mb_attention_mask
                    )
                    
                    #PPO loss
                    ratio = torch.exp(new_log_probs - mb_old_log_probs)    
                    surr1 = ratio*mb_advantages*mb_mask
                    surr2 = torch.clamp(ratio, 1-clip_eps, 1+clip_eps)* mb_advantages * mb_mask
                    policy_loss = -torch.min(surr1, surr2).sum() / mb_mask.sum()
                    #KL 
                    kl_div = (new_log_probs - mb_ref_log_probs)*mb_mask
                    kl_penalty = kl_div.sum() / mb_mask.sum()
                    #entropy
                    outputs = policy_model(
                        input_ids=mb_input_ids,
                        attention_mask=mb_attention_mask
                    )
                    
                    entropy=compute_entropy(outputs.logits[:, :-1, :], mb_mask)

                    total_policy_loss = policy_loss + kl_coef * kl_penalty - entropy_coef * entropy.mean()
                    #backward
                    if not torch.isnan(total_policy_loss):
                        total_policy_loss.backward()
                        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_norm=1.0)
                        policy_optimizer.step()

                    #value update
                    value_optimizer.zero_grad()
                    new_values = value_model(
                        input_ids=mb_input_ids,
                        attention_mask=mb_attention_mask
                    ).logits.squeeze(-1)
                    
                    # if new_values.dim() == 1:
                    #     seq_len = mb_input_ids.shape[1]
                    #     new_values = new_values.unsqueeze(-1).expand(-1, seq_len)
                    final_returns = mb_returns[:, -1]
                    #value loss
                    # value_loss = ((new_values-final_returns)**2*mb_mask).sum()/ mb_mask.sum()
                    value_loss = value_coef * ((new_values - final_returns) ** 2).mean()


                    # Backward and optimize
                    if not torch.isnan(value_loss):
                        value_loss.backward()
                        torch.nn.utils.clip_grad_norm_(value_model.parameters(), max_norm=1.0)
                        value_optimizer.step()

                    ppo_policy_losses.append(policy_loss.item())
                    ppo_value_losses.append(value_loss.item())
                    ppo_kl_penalties.append(kl_penalty.item())
                    ppo_entropies.append(entropy.mean().item())
                
                if b_idx % 10 == 0:
                    print(f"\n  Batch {b_idx} - PPO Epoch {ppo_epoch+1}/{ppo_epochs}")
                    print(f"    Policy Loss: {sum(ppo_policy_losses)/len(ppo_policy_losses):.4f}")
                    print(f"    Value Loss: {sum(ppo_value_losses)/len(ppo_value_losses):.4f}")
                    print(f"    KL Div: {sum(ppo_kl_penalties)/len(ppo_kl_penalties):.4f}")
                    print(f"    Entropy: {sum(ppo_entropies)/len(ppo_entropies):.4f}")
                    print(f"    Avg Reward: {batch_avg_reward:.4f}")
            
            del rollout_buffer
            torch.cuda.empty_cache()     
            
        epoch_avg_reward = sum(epoch_rewards) / len(epoch_rewards)
        all_epoch_rewards.append(epoch_avg_reward)
        print(f"\n{'='*50}")
        print(f"Epoch {epoch+1} Average Reward: {epoch_avg_reward:.4f}")
        print(f"{'='*50}\n")
    
    final_avg_reward = sum(all_epoch_rewards) / len(all_epoch_rewards)
    print(f"\n{'='*50}")
    print(f"FINAL AVERAGE REWARD: {final_avg_reward:.4f}")
    print(f"{'='*50}\n")

    return final_avg_reward

In [None]:
def objective(trial):

    clip_eps = trial.suggest_categorical('clip_eps', [0.1, 0.15, 0.2, 0.3])
    kl_coef = trial.suggest_categorical('kl_coef', [0.1, 0.2, 0.3, 0.4])
    value_coef = trial.suggest_categorical('value_coef', [0.5, 0.4, 0.6, 0.3])
    entropy_coef = trial.suggest_categorical('entropy_coef', [0.01, 0.03, 0.05, 0.04])
    gamma = trial.suggest_categorical('gamma', [0.99, 0.90, 0.85, 0.70])
    lam = trial.suggest_categorical('lam', [0.95, 0.87, 0.80, 0.76])
    # rollout_batch_size = trial.suggest_categorical('rollout_batch_size'. [8, 16, 32])
    # mini_batch_size = 
    policy_lr = trial.suggest_loguniform('policy_lr', 5e-5, 1e-3)
    value_lr = trial.suggest_loguniform('value_lr', 5e-5, 1e-3)
    
    print(f"\n{'='*60}")
    print(f"TRIAL {trial.number}")
    print(f"{'='*60}")
    print(f"Hyperparameters:")
    print(f"  clip_eps: {clip_eps}")
    print(f"  kl_coef: {kl_coef}")
    print(f"  value_coef: {value_coef}")
    print(f"  entropy_coef: {entropy_coef}")
    print(f"  gamma: {gamma}")
    print(f"  lam: {lam}")
    print(f"  policy_lr: {policy_lr:.2e}")
    print(f"  value_lr: {value_lr:.2e}")
    print(f"{'='*60}\n")

    trial_policy_model = PeftModel.from_pretrained(
        base_model,
        "/kaggle/input/loraadapters/pytorch/default/1",
        trainable=True
    )
    
    trial_value_model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=1,
        torch_dtype=torch.float16
    )
    for n,p in trial_policy_model.named_parameters():
        if 'lora' in n:
            p.requires_grad=True
            
    trial_value_model.config.pad_token_id = tokenizer.pad_token_id
    trial_value_model = get_peft_model(trial_value_model, value_config)
    trial_value_model.to('cuda')
    
    policy_optimizer = torch.optim.AdamW(trial_policy_model.parameters(), lr=policy_lr)
    value_optimizer = torch.optim.AdamW(trial_value_model.parameters(), lr=value_lr)
    
    small_dataset = tokenized_dataset.select(range(min(200, len(tokenized_dataset))))
    try:
        average_reward = tune_ppo_trainer(
            policy_model=trial_policy_model,
            value_model=trial_value_model,
            ref_model=ref_model,
            tokenizer=tokenizer,
            dataset=small_dataset,
            policy_optimizer=policy_optimizer,
            value_optimizer=value_optimizer,
            reward_model=None,
            num_epochs=1,  # Just 1 epoch for tuning
            rollout_batch_size=8,  # Smaller for faster tuning
            mini_batch_size=4,
            ppo_epochs=4,
            clip_eps=clip_eps,
            kl_coef=kl_coef,
            value_coef=value_coef,
            entropy_coef=entropy_coef,
            gamma=gamma,
            lam=lam,
            max_gen_len=64  # Shorter generation for speed
        )
    except Exception as e:
        print(f"Trial {trial.number} failed with error: {e}")
        return -10.0
    finally:
        del trial_policy_model, trial_value_model
        torch.cuda.empty_cache()
    return average_reward

In [None]:
import optuna

In [None]:
print("Starting Optuna hyperparameter search...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=20, timeout=3600*6)

In [None]:
best_params = study.best_params

In [None]:
policy_optimizer = torch.optim.AdamW(policy_model.parameters(), lr=best_params['policy_lr'])
value_optimizer = torch.optim.AdamW(value_model.parameters(), lr=best_params['value_lr'])

In [None]:
best_params

In [None]:
trained_policy, trained_value = ppo_trainer(
    policy_model=trial_policy_model,
    value_model=trial_value_model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=small_dataset,
    policy_optimizer=policy_optimizer,
    value_optimizer=value_optimizer,
    reward_model=None,
    num_epochs=8,  
    rollout_batch_size=32,  
    mini_batch_size=8,
    ppo_epochs=4,
    clip_eps=best_params['clip_eps'],
    kl_coef=best_params['kl_coef'],
    value_coef=best_params['value_coef'],
    entropy_coef=best_params['entropy_coef'],
    gamma=best_params['gamma'],
    lam=best_params['lam'],
    max_gen_len=128 
)
