In [None]:
# !pip install --upgrade torch torchvision torchaudio

In [1]:
# !pip install datasets
# !pip install pandas
# !pip install transformers
# !pip install sentencepiece
# !pip install tqdm
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from vlm_copy import VisionLanguageModel

  from .autonotebook import tqdm as notebook_tqdm


### Dataset_reward

In [4]:
from datasets import load_dataset

# Login using e.g. `huggingface-cli login` to access this dataset
ds2 = load_dataset("MMInstruction/VL-RewardBench")

In [5]:
ds2 

DatasetDict({
    test: Dataset({
        features: ['id', 'query', 'response', 'image', 'human_ranking', 'models', 'judge', 'rationale', 'query_source', 'ground_truth'],
        num_rows: 1250
    })
})

In [6]:
ds2 = ds2.filter(lambda x: x['human_ranking'] == [0,1])

In [7]:
len(ds2['test'])

1244

In [8]:
ds2 = ds2['test'].map(lambda x: {'image': x['image'], 'prompt': x['query'], 'chosen': x['response'][0], 'rejected': x['response'][1]}
                      , remove_columns=['human_ranking', 'response', 'id','models','judge','rationale', 'ground_truth','query_source','query'])

In [9]:
ds2

Dataset({
    features: ['image', 'prompt', 'chosen', 'rejected'],
    num_rows: 1244
})

In [10]:
import sentencepiece as spm
tokenizer = spm.SentencePieceProcessor(model_file='spm.model')

In [11]:
# Image preprocessing
import torchvision.transforms as transforms
image_transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
])

In [12]:
# Dạng dữ liệu mới: [{'image': PIL.Image, 'prompt': str, 'caption': str}, ...]
from PIL import Image
import numpy as np
import torch

def preprocess2(example):
    img_data = example["image"]

    # Đảm bảo ảnh là PIL.Image
    if isinstance(img_data, Image.Image):
        img = img_data.convert("RGB")
    elif isinstance(img_data, np.ndarray):
        img = Image.fromarray(img_data).convert("RGB")
    else:
        raise ValueError(f"Unsupported image format: {type(img_data)}")

    # Transform ảnh sang Tensor (3, 224, 224)
    image = image_transform(img)
    chosen = example["prompt"] + " " + example["chosen"]
    rejected = example["prompt"] + " " + example["rejected"]

    pad_id = tokenizer.pad_id() if tokenizer.pad_id() >= 0 else 0
    chosen_input_ids = tokenizer.encode(chosen)
    chosen_tokens = chosen_input_ids[:256]
    chosen_tokens += [pad_id] * (256 - len(chosen_tokens))
    chosen_input_ids = torch.tensor(chosen_tokens, dtype=torch.long)

    rejected_input_ids = tokenizer.encode(rejected)
    rejected_tokens = rejected_input_ids[:256]
    rejected_tokens += [pad_id] * (256 - len(rejected_tokens))
    rejected_input_ids = torch.tensor(rejected_tokens, dtype=torch.long)
    return {
        "image": image,
        "chosen_input_ids": chosen_input_ids,
        "reject_input_ids": rejected_input_ids
    }

dataset2 = list(map(preprocess2, ds2))



In [13]:
def collate_fn2(batch):
    imgs = torch.stack([torch.tensor(item['image']) if not isinstance(item['image'], torch.Tensor) else item['image'] for item in batch])
    chosen_input_ids = torch.stack([item["chosen_input_ids"] for item in batch])
    reject_input_ids = torch.stack([item["reject_input_ids"] for item in batch])
    return imgs, chosen_input_ids, reject_input_ids


In [14]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device

device(type='cpu')

# 1.1 SFT base model

In [15]:
n_embed, num_hiddens, num_heads, n_layer = 128, 512, 8, 8
image_embed_dim = num_hiddens
img_size = 96
patch_size = 16
num_blocks = 2

n_layer, block_size, num_hiddens = 8, 32, 512

# Initialize the model
vlm = VisionLanguageModel(
    n_embed=n_embed,
    img_embed_dim=image_embed_dim,
    vocab_size=tokenizer.vocab_size(),
    n_layer=n_layer,
    img_size=img_size,
    patch_size=patch_size,
    num_heads=num_heads,
    num_blocks=num_blocks,
    emb_dropout=0.1,
    block_dropout=0.1,
)
device = torch.device('cpu')
vlm.to(device)

# vlm.load_state_dict(torch.load("./checkpoints/vlm_best.pt")['model_state_dict'])
# vlm.eval()  # set to eval mode if you're going to do inference

VisionLanguageModel(
  (vision_encoder): ViT(
    (patch_embedding): PatchEmbeddings(
      (conv): Conv2d(3, 512, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (blocks): ModuleList(
      (0-1): 2 x Block(
        (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mhattn): MultiHeadAttention(
          (heads): ModuleList(
            (0-7): 8 x Head(
              (key): Linear(in_features=512, out_features=64, bias=False)
              (query): Linear(in_features=512, out_features=64, bias=False)
              (value): Linear(in_features=512, out_features=64, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (proj): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (ffn): Sequential(
          (0): Linear(in_features=512, out_

## Reward

In [20]:
# Create DataLoader
from torch.utils.data import DataLoader
dataloader2 = DataLoader(dataset2, batch_size=32, shuffle=True, collate_fn=collate_fn2)
dataset2[4]

{'image': tensor([[[0.3176, 0.4314, 0.4980,  ..., 0.8863, 0.8706, 0.8588],
          [0.4510, 0.4745, 0.4863,  ..., 0.8980, 0.8863, 0.8706],
          [0.4706, 0.4039, 0.4000,  ..., 0.9137, 0.9059, 0.8902],
          ...,
          [0.5451, 0.5647, 0.6157,  ..., 0.6745, 0.6431, 0.6314],
          [0.5412, 0.5608, 0.6039,  ..., 0.6510, 0.6314, 0.6471],
          [0.5686, 0.5843, 0.6118,  ..., 0.6431, 0.6235, 0.6392]],
 
         [[0.3333, 0.4588, 0.5647,  ..., 0.9961, 0.9922, 0.9922],
          [0.5098, 0.5529, 0.5686,  ..., 0.9961, 0.9922, 0.9922],
          [0.5608, 0.4863, 0.4588,  ..., 0.9922, 0.9922, 0.9922],
          ...,
          [0.1725, 0.1882, 0.2000,  ..., 0.6314, 0.6039, 0.5922],
          [0.1765, 0.1843, 0.1882,  ..., 0.6078, 0.5922, 0.6078],
          [0.1922, 0.2000, 0.1882,  ..., 0.6000, 0.5843, 0.6000]],
 
         [[0.3686, 0.5137, 0.6431,  ..., 0.9882, 0.9961, 0.9922],
          [0.5725, 0.6314, 0.6549,  ..., 0.9882, 0.9961, 0.9882],
          [0.6353, 0.5490, 0.51

In [None]:
import torch
import torch.nn as nn

class RewardModel(nn.Module):
    def __init__(self, base_model: VisionLanguageModel, hidden_dim=128):
        super().__init__()
        self.base_model = base_model  
        self.reward_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.gain = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, img_array: torch.Tensor, input_ids: torch.Tensor):
        image_embeds = self.base_model.vision_encoder(img_array)
        # Get hidden states from decoder
        hidden_states = self.base_model.decoder(
            idx=input_ids,
            img_embeds=image_embeds,
            return_hidden=True  # This is the key addition
        )
        last_hidden = hidden_states[:, -1, :]  # last token representation
        reward = self.reward_head(last_hidden).squeeze(-1)
        return self.gain * reward + self.bias          #???

In [18]:
# Wrap in RewardModel
reward_model = RewardModel(base_model=vlm, hidden_dim=128).to(device)

# Example input
img_tensor = torch.randn(4, 3, 96, 96).to(device)      
text_input = torch.randint(0, 30522, (4, 64)).to(device)  

print(img_tensor.shape)
print(text_input.shape)

# Forward pass to get reward values
rewards = reward_model(img_tensor, text_input)
print("Rewards:", rewards)


torch.Size([4, 3, 96, 96])
torch.Size([4, 64])
Rewards: tensor([0.2629, 0.3891, 0.2147, 0.1775], grad_fn=<AddBackward0>)


In [19]:
optimizer = torch.optim.Adam(reward_model.parameters(), lr=1e-5)

In [24]:
def train_reward(reward_model, dataloader2, optimizer, num_epochs=100):
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in tqdm(dataloader2, desc=f"Epoch {epoch+1}/{num_epochs}"):
            image, response_a, response_b = batch
            
            # print(image.shape)                  # torch.Size([32, 3, 96, 96])
            # print(response_a.shape)             # torch.Size([32, 256])
            # print(response_b.shape)             # torch.Size([32, 256])

            r_a = reward_model(image, response_a)
            r_b = reward_model(image, response_b)

            loss = -torch.nn.functional.logsigmoid(r_a - r_b).mean()
            total_loss += loss.item()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        avg_loss = total_loss / len(dataloader2)
        print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")

In [25]:
train_reward(reward_model, dataloader2, optimizer, num_epochs=10)

Epoch 1/10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [04:02<00:00,  6.23s/it]


[Epoch 1] Loss: 0.6918


Epoch 2/10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [03:54<00:00,  6.02s/it]


[Epoch 2] Loss: 0.6876


Epoch 3/10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [03:49<00:00,  5.89s/it]


[Epoch 3] Loss: 0.6843


Epoch 4/10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [03:40<00:00,  5.65s/it]


[Epoch 4] Loss: 0.6830


Epoch 5/10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [03:40<00:00,  5.65s/it]


[Epoch 5] Loss: 0.6793


Epoch 6/10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [03:40<00:00,  5.66s/it]


[Epoch 6] Loss: 0.6771


Epoch 7/10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [03:40<00:00,  5.66s/it]


[Epoch 7] Loss: 0.6765


Epoch 8/10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [03:40<00:00,  5.65s/it]


[Epoch 8] Loss: 0.6741


Epoch 9/10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [03:40<00:00,  5.66s/it]


[Epoch 9] Loss: 0.6719


Epoch 10/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [03:41<00:00,  5.67s/it]

[Epoch 10] Loss: 0.6694





In [26]:
torch.save(reward_model.state_dict(), "reward_model.pth")

In [27]:
sample = dataset2[3]

In [28]:
print(sample)

{'image': tensor([[[0.1529, 0.1333, 0.1255,  ..., 0.2118, 0.1922, 0.1686],
         [0.1373, 0.1333, 0.1412,  ..., 0.2000, 0.2078, 0.2235],
         [0.1490, 0.1647, 0.1686,  ..., 0.1843, 0.2078, 0.2275],
         ...,
         [0.1294, 0.1294, 0.1137,  ..., 0.1765, 0.1569, 0.1569],
         [0.1294, 0.1294, 0.1255,  ..., 0.1765, 0.1686, 0.1529],
         [0.1294, 0.1216, 0.1216,  ..., 0.1725, 0.1725, 0.1725]],

        [[0.1333, 0.1216, 0.1333,  ..., 0.1922, 0.1686, 0.1451],
         [0.1216, 0.1255, 0.1490,  ..., 0.2000, 0.2039, 0.2039],
         [0.1373, 0.1608, 0.1804,  ..., 0.1882, 0.2078, 0.2118],
         ...,
         [0.1412, 0.1412, 0.1255,  ..., 0.1804, 0.1686, 0.1765],
         [0.1412, 0.1412, 0.1412,  ..., 0.1725, 0.1843, 0.1686],
         [0.1412, 0.1333, 0.1373,  ..., 0.1686, 0.1882, 0.1843]],

        [[0.3373, 0.3412, 0.3333,  ..., 0.4039, 0.3922, 0.4353],
         [0.3216, 0.3333, 0.3373,  ..., 0.4431, 0.4549, 0.5294],
         [0.3294, 0.3490, 0.3412,  ..., 0.4588, 

In [29]:
img = sample['image'].unsqueeze(0)
r_a = sample['chosen_input_ids'].unsqueeze(0)
r_b = sample['reject_input_ids'].unsqueeze(0)
print(img.shape)
print(r_a.shape)
print(r_b.shape)

torch.Size([1, 3, 96, 96])
torch.Size([1, 256])
torch.Size([1, 256])


In [30]:
device = torch.device("cpu")
img = img.to(device)
r_a = r_a.to(device)
r_b = r_b.to(device)

reward_model = RewardModel(base_model=vlm, hidden_dim=128)  # make sure dimensions match training

# Load weights
reward_model.load_state_dict(torch.load("reward_model.pth"))
reward_model.eval()
reward_model.to(device)

# --- Run inference ---
with torch.no_grad():
    reward1 = reward_model(img, r_a)
    reward2 = reward_model(img, r_b)


In [31]:
print(reward1.item())
print(reward2.item())

1.4765957593917847
1.334306240081787


# PPO

In [None]:
# from vlm import VisionLanguageModel

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

def ppo_prepare(
    vlm_model,
    reward_model,
    dataloader,
    batch_size=32,
    gamma=0.99, # gamma and lam for GAE but we not use GAE
    lam=0.95
):
    # Step 1: Collect old log probs, values, and generated responses
    vlm_model.eval()
    old_log_probs_list = []
    old_values_list = []
    generated_responses_list = []

    with torch.no_grad():
        for batch in dataloader:
            images, input_ids, targets = batch

            logits, values = vlm_model(images, input_ids) 
            values = values[:, -1]  # [B]
            print(values.shape)
            log_probs = F.log_softmax(logits, dim=-1)   # [B, T, 128000] 128000 is vocabsize
            target_log_probs = log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1)     # unsqueeze de phu hop voi dau vao cua ham gather, 2 la dimension 2 ([B, T, 1] --> T)

            old_log_probs_list.append(target_log_probs)
            old_values_list.append(values.squeeze(-1))  # shape: [B]
            print(len(old_values_list[0]))

            # responses = vlm_model(images, input_ids, max_new_tokens=50)  # Generated responses (tokens or ids), goi ham generated() thay vi forward
            responses = torch.argmax(logits, dim=-1)
            full_responses = torch.cat([input_ids, responses], dim=1)  # shape: [B, T+max_new_tokens]
            generated_responses_list.append(full_responses.view(-1))

    old_log_probs = torch.cat(old_log_probs_list, dim=0)       # shape: [B, T]
    old_values = torch.cat(old_values_list, dim=0)             # shape: [B, T+1]
    generated_responses = torch.cat(generated_responses_list, dim=0)  # shape: [B, <= T+max_new_tokens]

    # Step 2: Compute scalar rewards
    rewards = reward_model(images, generated_responses).detach()  # shape: [B]

    # Step 3: Compute advantages without GAE (if u want to use GAE, the rewards should be in shape [B, T])
    advantages = rewards - old_values

    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # Step 4: Prepare dataset
    dataset = TensorDataset(images, input_ids, targets, old_log_probs, old_values, rewards, advantages)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    return dataloader

In [38]:
batch_size = 32
num_samples = 10  # Example sample size
images = torch.randn(num_samples, 3, 96, 96)  # Example image tensor
input_ids = torch.randint(0, tokenizer.vocab_size(), (num_samples, 256))  # Example tokenized text (e.g., question)
targets = torch.randint(0, tokenizer.vocab_size(), (num_samples, 256))  # Example target indices (e.g., word IDs for response)

# Create DataLoader (for batching)
dataset = TensorDataset(images, input_ids, targets)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [89]:
dataloader_ppo = ppo_prepare(vlm, reward_model, dataloader, batch_size=32, gamma= 0.99, lam= 0.95)

torch.Size([10])
10


In [93]:
print(torch.tensor([0, 1, 2,3,4,5,6,7,8,9]).squeeze(-1))

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])


In [None]:
def ppo_train_with_kl(
    vlm_model,
    optimizer,
    dataloader, 
    clip_ratio=0.2,
    ppo_epochs=10,
    kl_coeff=0.01
):
    # Step 5: PPO optimization loop
    vlm_model.train()
    for epoch in range(ppo_epochs):
        for batch in dataloader:
            batch_images, batch_input_ids, batch_targets, old_log_probs_batch, old_values_batch, rewards_batch, adv_batch = batch

            logits, values = vlm_model(batch_images, batch_input_ids)
            values = values[:, -1]
            log_probs = F.log_softmax(logits, dim=-1)
            new_log_probs = log_probs.gather(2, batch_targets.unsqueeze(-1)).squeeze(-1)

            # PPO ratio
            new_log_probs = new_log_probs[:, -1]           # shape [B]
            old_log_probs_batch = old_log_probs_batch[:, -1]  # shape [B]
            ratio = torch.exp(new_log_probs - old_log_probs_batch)

            # PPO Clipped Objective
            clipped_ratio = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)
            ppo_objective = torch.min(ratio * adv_batch, clipped_ratio * adv_batch)
            ppo_loss = -ppo_objective.mean()

            # KL divergence: D_KL(old || new)
            kl_div = (old_log_probs_batch.exp() * (old_log_probs_batch - new_log_probs)).mean()

            # Value loss
            value_loss = F.mse_loss(values.squeeze(-1), rewards_batch)

            # Total loss
            total_loss = ppo_loss + value_loss + kl_coeff * kl_div

            # Backpropagation
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        print(f"[Epoch {epoch + 1}] PPO Loss: {ppo_loss.item():.4f}, "
              f"Value Loss: {value_loss.item():.4f}, KL: {kl_div.item():.6f}")


In [73]:
n_embed, num_hiddens, num_heads, n_layer = 128, 512, 8, 8
image_embed_dim = num_hiddens
img_size = 96
patch_size = 16
num_blocks = 2

n_layer, block_size, num_hiddens = 8, 32, 512

# Initialize the model
vlm = VisionLanguageModel(
    n_embed=n_embed,
    img_embed_dim=image_embed_dim,
    vocab_size=tokenizer.vocab_size(),
    n_layer=n_layer,
    img_size=img_size,
    patch_size=patch_size,
    num_heads=num_heads,
    num_blocks=num_blocks,
    emb_dropout=0.1,
    block_dropout=0.1,
)
device = torch.device('cpu')
vlm.to(device)

reward_model = RewardModel(base_model=vlm, hidden_dim=128)
optimizer = optim.Adam(vlm.parameters(), lr=1e-5)

In [85]:
ppo_train_with_kl(vlm, optimizer, dataloader_ppo, clip_ratio=0.2, ppo_epochs=100, kl_coeff=0.01)

[Epoch 1] PPO Loss: -0.1069, Value Loss: 0.0385, KL: 0.000001
[Epoch 2] PPO Loss: -0.1119, Value Loss: 0.0216, KL: 0.000001
[Epoch 3] PPO Loss: -0.1185, Value Loss: 0.0236, KL: 0.000001
[Epoch 4] PPO Loss: -0.1298, Value Loss: 0.0284, KL: 0.000001
[Epoch 5] PPO Loss: -0.1341, Value Loss: 0.0222, KL: 0.000001
[Epoch 6] PPO Loss: -0.1359, Value Loss: 0.0185, KL: 0.000001
[Epoch 7] PPO Loss: -0.1388, Value Loss: 0.0195, KL: 0.000001
[Epoch 8] PPO Loss: -0.1355, Value Loss: 0.0138, KL: 0.000001
[Epoch 9] PPO Loss: -0.1409, Value Loss: 0.0135, KL: 0.000001
[Epoch 10] PPO Loss: -0.1377, Value Loss: 0.0111, KL: 0.000001
[Epoch 11] PPO Loss: -0.1402, Value Loss: 0.0113, KL: 0.000001
[Epoch 12] PPO Loss: -0.1394, Value Loss: 0.0154, KL: 0.000001
[Epoch 13] PPO Loss: -0.1400, Value Loss: 0.0121, KL: 0.000001
[Epoch 14] PPO Loss: -0.1400, Value Loss: 0.0089, KL: 0.000001
[Epoch 15] PPO Loss: -0.1402, Value Loss: 0.0098, KL: 0.000001
[Epoch 16] PPO Loss: -0.1417, Value Loss: 0.0032, KL: 0.000001
[

In [None]:
# import torch
# import torch.nn.functional as F
# from torch.utils.data import DataLoader, TensorDataset

# def ppo_train_with_kl(
#     model,
#     reward_model,
#     optimizer,
#     images,
#     input_ids,
#     targets, 
#     clip_ratio=0.2,
#     ppo_epochs=4,
#     batch_size=32,
#     gamma=0.99,
#     lam=0.95,
#     kl_coeff=0.01
# ):
#     # Step 1: Collect old log probs, values, and generated responses
#     model.eval()
#     old_log_probs_list = []
#     old_values_list = []
#     generated_responses_list = []

#     with torch.no_grad():
#         for i in range(0, len(input_ids), batch_size):
#             batch_images = images[i:i + batch_size]
#             batch_input_ids = input_ids[i:i + batch_size]

#             logits, values = model(batch_images, batch_input_ids) 
#             log_probs = F.log_softmax(logits, dim=-1)
#             target_log_probs = log_probs.gather(2, batch_targets.unsqueeze(-1)).squeeze(-1)

#             old_log_probs_list.append(target_log_probs)
#             old_values_list.append(values.squeeze(-1))  # shape: [B]

#             responses = model(batch_images, batch_input_ids)  # Generated responses (tokens or ids), goi ham generated() thay vi forward
#             generated_responses_list.append(responses)

#     old_log_probs = torch.cat(old_log_probs_list, dim=0)       # shape: [N, T]
#     old_values = torch.cat(old_values_list, dim=0)             # shape: [N]
#     generated_responses = torch.cat(generated_responses_list, dim=0)  # shape: [N, T]

#     # Step 2: Compute scalar rewards
#     rewards = reward_model(images, generated_responses).detach()  # shape: [N]

#     # Step 3: Compute advantages using GAE stragy
#     with torch.no_grad():
#         last_value = old_values[-1].unsqueeze(0)
#         all_values = torch.cat([old_values, last_value], dim=0)

#     advantages = torch.zeros_like(rewards)
#     gae = 0
#     for t in reversed(range(len(rewards))):
#         delta = rewards[t] + gamma * all_values[t + 1] - all_values[t]
#         gae = delta + gamma * lam * gae
#         advantages[t] = gae

#     advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

#     # Step 4: Prepare dataset
#     dataset = TensorDataset(images, input_ids, targets, old_log_probs, old_values, rewards, advantages)
#     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

#     # Step 5: PPO optimization loop
#     model.train()
#     for epoch in range(ppo_epochs):
#         for batch in dataloader:
#             batch_images, batch_input_ids, batch_targets, old_log_probs_batch, old_values_batch, rewards_batch, adv_batch = batch

#             logits, values = model(batch_images, batch_input_ids)
#             log_probs = F.log_softmax(logits, dim=-1)
#             new_log_probs = log_probs.gather(2, batch_targets.unsqueeze(-1)).squeeze(-1)

#             # PPO ratio
#             ratio = torch.exp(new_log_probs - old_log_probs_batch)

#             # PPO Clipped Objective
#             clipped_ratio = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)
#             ppo_objective = torch.min(ratio * adv_batch, clipped_ratio * adv_batch)
#             ppo_loss = -ppo_objective.mean()

#             # KL divergence: D_KL(old || new)
#             kl_div = (old_log_probs_batch.exp() * (old_log_probs_batch - new_log_probs)).mean()

#             # Value loss
#             value_loss = F.mse_loss(values.squeeze(-1), rewards_batch)

#             # Total loss
#             total_loss = ppo_loss + value_loss + kl_coeff * kl_div

#             # Backpropagation
#             optimizer.zero_grad()
#             total_loss.backward()
#             optimizer.step()

#         print(f"[Epoch {epoch + 1}] PPO Loss: {ppo_loss.item():.4f}, "
#               f"Value Loss: {value_loss.item():.4f}, KL: {kl_div.item():.6f}")
