In [None]:
# train_clip_rl.ipynb

import torch
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

# Load the pre-trained CLIP model and its matching processor, then restore the
# weights produced by the earlier supervised fine-tuning run.
# NOTE(review): `device`, `train_dataset` and `preprocess_data` are expected to
# come from earlier notebook cells — they are not created here.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# any host; the model is moved to `device` immediately afterwards.
checkpoint = torch.load("clip_finetuned.pt", map_location="cpu")
clip_model.load_state_dict(checkpoint["model_state_dict"])
del checkpoint  # drop the CPU copy of the checkpoint before training starts
clip_model.to(device)
clip_model.train()

# Reduce dataset size for RL fine-tuning: keep a random 40% subset of the
# training set. The seeded generator makes the chosen subset reproducible.
rl_fraction = 0.4
rl_train_size = int(rl_fraction * len(train_dataset))
rl_train_dataset, _ = random_split(
    train_dataset,
    [rl_train_size, len(train_dataset) - rl_train_size],
    generator=torch.Generator().manual_seed(42),
)
rl_train_dataloader = DataLoader(
    rl_train_dataset,
    batch_size=6,
    shuffle=True,
    collate_fn=preprocess_data,
)

# Reward for RL fine-tuning: the model's softmax confidence on samples it
# classifies correctly, averaged over the batch (incorrect samples score 0).
def reward_function(predictions, ground_truth, logits=None):
    """Return the batch-mean confidence on correctly predicted samples.

    Args:
        predictions: Predicted class indices (the loop below passes
            ``argmax(logits, dim=1)``).
        ground_truth: Target class indices, compared element-wise with
            ``predictions``.
        logits: Raw scores the predictions were taken from; softmax is taken
            over ``dim=1``. Earlier versions of this function silently read a
            global ``logits`` instead; passing ``None`` keeps that legacy
            behaviour, but callers should pass the tensor explicitly.

    Returns:
        A scalar tensor. Gradient flows through ``logits`` via the softmax
        confidence; the ``correct`` mask itself carries no gradient, so only
        correctly predicted samples contribute to the update.
    """
    if logits is None:
        # Legacy fallback: reproduce the original implicit lookup of the
        # notebook-level `logits` variable.
        logits = globals()["logits"]
    correct = (predictions == ground_truth).float()
    confidence = torch.softmax(logits, dim=1).max(dim=1).values
    return torch.mean(correct * confidence)

# RL fine-tuning loop: maximise the reward by minimising its negation.
optimizer = torch.optim.Adam(clip_model.parameters(), lr=1e-5)
num_rl_epochs = 3

for epoch in range(num_rl_epochs):
    total_reward = 0.0
    print(f"Starting RL epoch {epoch + 1}")
    with tqdm(total=len(rl_train_dataloader), desc=f"RL Epoch {epoch + 1}", unit="batch") as pbar:
        for batch in rl_train_dataloader:
            optimizer.zero_grad()
            # NOTE(review): assumes `preprocess_data` yields keyword tensors
            # already placed on `device` — confirm against the collate fn.
            outputs = clip_model(**batch)
            logits = outputs.logits_per_image
            # Image i is treated as the positive match for text i in the batch.
            labels = torch.arange(len(logits), device=device)
            predictions = torch.argmax(logits, dim=1)
            rewards = reward_function(predictions, labels, logits)
            (-rewards).backward()  # Minimize negative rewards
            optimizer.step()

            batch_reward = rewards.item()
            total_reward += batch_reward
            pbar.set_postfix(reward=f"{batch_reward:.4f}")
            pbar.update(1)

    # Guard against an empty dataloader (e.g. a very small train_dataset).
    num_batches = max(len(rl_train_dataloader), 1)
    print(f"RL Epoch {epoch + 1}: Average Reward = {total_reward / num_batches:.4f}")

# Persist the RL fine-tuned model and the processor into separate directories
# via `save_pretrained`; the model is written first, then the processor.
RL_MODEL_DIR = "clip_rl_finetuned"
RL_PROCESSOR_DIR = "clip_rl_finetuned_processor"

clip_model.save_pretrained(RL_MODEL_DIR)
clip_processor.save_pretrained(RL_PROCESSOR_DIR)

print("RL Fine-tuning completed and model saved.")
