In [1]:
from pathlib import Path

import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Blip2ForConditionalGeneration,
    Blip2Processor,
    BlipForConditionalGeneration,
    BlipProcessor,
    pipeline,
)
from transformers.image_utils import load_image as hf_load_image
from trl import PPOConfig, PPOTrainer

In [None]:
# Run everything on the GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# "Salesforce/blip2-opt-2.7b" is a BLIP-2 checkpoint, so it has to be loaded
# with the Blip2* classes; the BLIP-1 classes (BlipProcessor /
# BlipForConditionalGeneration) describe a different architecture.
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(device)

# Reward model: a sequence-classification model with a single output logit.
# NOTE(review): "facebook/opt-2.7b" is a plain language-model checkpoint, so
# the classification head is presumably freshly initialised here -- confirm
# that a trained reward head is loaded before trusting the scores.
reward_model_name = "facebook/opt-2.7b"  # Change to a smaller model if needed
reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_name)
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_name, num_labels=1).to(device)
reward_model.eval()  # the reward model is only used for scoring, never trained here

In [None]:
# Dataset: Load in dataset
# To be completed

def load_image(image_path):
    """Load the image stored at ``image_path``.

    Args:
        image_path: Path (``str`` or ``os.PathLike``) of a local image file.

    Returns:
        The image object produced by ``transformers.image_utils.load_image``
        (a PIL image according to the transformers documentation).

    Raises:
        FileNotFoundError: If ``image_path`` does not point to an existing file.
    """
    path = Path(image_path)
    # Fail early with a clear message instead of silently returning None.
    if not path.is_file():
        raise FileNotFoundError(f"No image file found at {path}")
    return hf_load_image(str(path))

# NOTE(review): the dataset is still a placeholder. Cells that index
# dataset["train"] cannot run until this is replaced with a real dataset.
dataset = None

In [None]:
def compute_reward(text_explanation):
    """Score one or more explanations with the reward model.

    Args:
        text_explanation: A single string, or a sequence of strings to score
            together in one padded batch.

    Returns:
        A ``float`` when a single string is given, otherwise a list with one
        ``float`` per input text (an empty list for an empty sequence).
    """
    is_single = isinstance(text_explanation, str)
    texts = [text_explanation] if is_single else list(text_explanation)
    if not texts:
        return []

    inputs = reward_tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        logits = reward_model(**inputs).logits

    # Flatten to one score per text; calling .item() on the whole tensor only
    # works when exactly one text was scored.
    scores = logits.reshape(-1).float().tolist()
    return scores[0] if is_single else scores

In [None]:
# PPO hyper-parameters (legacy TRL PPOConfig keywords).
config = PPOConfig(
    model_name="Salesforce/blip2-opt-2.7b",
    learning_rate=1e-5,
    batch_size=4,
    # Pin the mini-batch size instead of relying on the library default:
    # batch_size has to be a multiple of the mini-batch size, and 1 divides
    # any batch size.
    # NOTE(review): raise this if GPU memory allows larger optimisation steps.
    mini_batch_size=1,
    log_with="wandb",  # Set up Weights & Biases for logging
)

In [None]:
# Fail with an actionable message while the dataset cell is still a
# placeholder, instead of "'NoneType' object is not subscriptable".
if dataset is None:
    raise RuntimeError(
        "`dataset` has not been loaded yet; populate it (with a 'train' split) "
        "before building the PPO trainer."
    )

# NOTE(review): the legacy TRL PPOTrainer appears to require a value-head
# wrapped model (e.g. AutoModelForCausalLMWithValueHead); `model` is passed
# here unwrapped -- confirm this is accepted by the installed TRL version.
ppo_trainer = PPOTrainer(
    model=model,
    config=config,
    dataset=dataset["train"],
    tokenizer=processor.tokenizer,
)

In [None]:
# PPOTrainer.step() takes three lists of tensors (queries, responses, scores)
# holding exactly `config.batch_size` samples, so samples are buffered here
# until a full batch is available.
queries, responses, rewards = [], [], []

for epoch in range(3):  # Train for 3 epochs
    for batch in dataset["train"]:
        # NOTE(review): assumes each item carries the pixel tensor of a single
        # image, already on the model's device -- confirm against the dataset.
        query_images = batch["pixel_values"]  # ARC input images

        # Generate explanations. Image tensors must be passed as
        # `pixel_values`; `input_ids` is reserved for token ids.
        response = model.generate(pixel_values=query_images, max_length=50)

        # Compute rewards on the clean text (no special tokens).
        reward = compute_reward(
            processor.tokenizer.decode(response[0], skip_special_tokens=True)
        )

        # NOTE(review): the image tensor is kept as the "query", as in the
        # original code; the trainer presumably expects token ids here.
        queries.append(query_images)
        responses.append(response[0])
        rewards.append(torch.tensor(reward))

        # Train PPO once a full batch has been collected.
        if len(queries) == config.batch_size:
            ppo_trainer.step(queries, responses, rewards)
            queries, responses, rewards = [], [], []

# Samples left over in an incomplete final batch are not used for training.

# Save the trained PPO model
model.save_pretrained("./arc_blip2_ppo")
processor.save_pretrained("./arc_blip2_ppo")

In [None]:
# Long-form task prompt describing the ARC puzzle format and the expected
# answer (an output grid of integers 0-9, between 1x1 and 30x30).
# NOTE(review): no other code visible in this notebook reads `query`;
# solve_arc_task uses its own short prompt -- confirm which one is intended.
query = """This image is a logic puzzle that contains input and output examples. Your job is to learn
                the transformation between the input and output grids and apply it to the final grid, which has no associated output (appears as a black box).
                In addition to the image, I will give you an array representation for all input and output grids seen. This grid is
                a rectangular matrix (list of lists) of integers between 0 and 9 (inclusive). The smallest possible grid size is 
                1x1 and the largest is 30x30.

                Your goal is to construct the output grid corresponding to the test input grid. 
                "Constructing the output grid" involves picking the height and width of the output grid, then filling each 
                cell in the grid with a symbol (integer between 0 and 9, which are visualized as colors). Only exact solutions 
                (all cells match the expected answer) can be said to be correct. Please output an array representing the output grid.
                If you are able to generate images, output an image visualization of the grid."""

In [None]:
def solve_arc_task(image_path, prompt="What is the transformation?", max_new_tokens=50):
    """Generate a textual explanation for the ARC puzzle image at ``image_path``.

    Args:
        image_path: Location of the puzzle image, forwarded to ``load_image``.
        prompt: Text prompt given to the model together with the image.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded explanation with special tokens removed and surrounding
        whitespace stripped.

    Raises:
        ValueError: If ``load_image`` returns ``None`` for ``image_path``.
    """
    image = load_image(image_path)
    if image is None:
        raise ValueError(f"load_image returned no image for {image_path!r}")

    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

    # Generate explanation. Forward every processor output under its own
    # keyword, so the pixel values are not mistaken for token ids and the
    # text prompt actually reaches the model.
    output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    explanation = processor.tokenizer.decode(output[0], skip_special_tokens=True).strip()

    return explanation

In [None]:
# Run the solver on a (placeholder) puzzle image and show its explanation.
explanation = solve_arc_task("path/to/new_arc_problem.png")
print(explanation)