# Gemini pass/maj-at-k experiments

We run Gemini 1.5 Pro v2 (`gemini-1.5-pro-002`) with temperature 1 and top_p 0.95, sampling 64 completions per problem. We use the zero-shot (short answer) prompt and request 8 completions for each of the seeds 0-7, since each call supports at most 8 completions. These samples are then used to compute pass@k and maj@k results for k equal to the powers of 2 from 1 to 64.

In [1]:
import sys
from pathlib import Path

# Location of the repo's `src` directory.
# Replace with your own basedir path for the repo.
BASEDIR = Path("/workspaces/HARP/", "src")

# Prepend it to the import search path (same effect as `insert(0, ...)`).
sys.path[:0] = [str(BASEDIR)]

In [2]:
from __future__ import annotations

import copy
import itertools
import json
import math
import os
import pickle
import pprint
import re
import textwrap
import time
import traceback
from collections import Counter, defaultdict
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tiktoken
from IPython.display import Markdown, clear_output, display
from tqdm.auto import tqdm

import vertexai
from vertexai.batch_prediction._batch_prediction import BatchPredictionJob

In [3]:
from eval.api import safe_unified_api_call
from eval.costs import count_tokens, get_pricing
from eval.eval import run_one, create_batch, make_answer_check_dict_from_jsonl
from eval.parsing_lib import *
from eval.latex_answer_check import *
from eval.prompt import create_prompt
from eval.prompts import *
from eval.utils import read_jsonl, write_jsonl, get_uid, upload_blob, download_blob

# Data

In [4]:
# Read the processed HARP problems from disk.
dataset = read_jsonl(BASEDIR.joinpath("data/processed/HARP.jsonl"))

# Index the problems by their uid; a later duplicate uid overwrites an earlier one.
dataset_map = dict(zip(map(get_uid, dataset), dataset))

# Last expression of the cell: displays the number of problems loaded.
len(dataset)

4780

# Run eval

In [5]:
# Initialise the Vertex AI SDK. The project id comes from the
# VERTEXAI_PROJECT_ID environment variable (None is passed when it is unset).
vertexai.init(
    project=os.getenv("VERTEXAI_PROJECT_ID"),
    location="us-central1",
)

In [6]:
# Cloud Storage bucket name, read from the GCLOUD_BUCKET_NAME environment
# variable (None when unset).
# Should have the form "cloud-ai-platform-<YOUR_BUCKET>"
BUCKET_NAME = os.getenv("GCLOUD_BUCKET_NAME")

## Create batch

In [19]:
# Goal: 64 samples per problem. One request yields at most 8 completions,
# so we build 8 batch files that are identical except for the seed (0-7).
for seed in range(8):
    # Keyword arguments for this seed's batch; rebuilt on every iteration so
    # each call receives fresh list objects.
    request_kwargs = {
        "api": "google",
        "model": "gemini-1.5-pro-002",
        "fewshot_messages": [],
        "system_prompt": gemini_0shot_sysprompt,
        "max_tokens": 2048,
        "num_completions": 8,
        "temperature": 1,
        "top_p": 0.95,
        "seed": seed,
        "stop_sequences": ["I hope it is correct."],
        # just to remove irrelevant params
        "logprobs": None,
    }
    batch = create_batch(dataset, **request_kwargs)

    # One input file per seed.
    batch_path = BASEDIR / f"inputs/short_answer/gemini-1.5-pro-002/batch_passk_seed{seed}.jsonl"
    write_jsonl(batch, batch_path)

## Upload to cloud

In [None]:
# Upload each seed's batch input file to the bucket under `prompt_data/`.
for seed in range(8):
    local_batch = BASEDIR / f"inputs/short_answer/gemini-1.5-pro-002/batch_passk_seed{seed}.jsonl"
    remote_batch = f"prompt_data/short_answer/gemini-1.5-pro-002/batch_passk_seed{seed}.jsonl"
    upload_blob(BUCKET_NAME, local_batch, remote_batch)

## Run batch job

In [None]:
# Submit one batch prediction job per seed: each reads that seed's uploaded
# input file and writes under a seed-specific output prefix in the bucket.
for seed in range(8):
    input_uri = f"gs://{BUCKET_NAME}/prompt_data/short_answer/gemini-1.5-pro-002/batch_passk_seed{seed}.jsonl"
    output_prefix = f"gs://{BUCKET_NAME}/outputs/short_answer/gemini-1.5-pro-002/batch_passk_seed{seed}"
    BatchPredictionJob.submit(
        source_model="gemini-1.5-pro-002",
        input_dataset=input_uri,
        output_uri_prefix=output_prefix,
    )

In [None]:
# Download the predictions file of each of the 8 seed runs from the bucket.
#
# The remote path contains a `prediction-model-<TIMESTAMP>` component that is
# specific to each job: replace every "<TIMESTAMP>" placeholder below with the
# value for that seed's job before running this cell.
prediction_timestamps = {
    0: "<TIMESTAMP>",
    1: "<TIMESTAMP>",
    2: "<TIMESTAMP>",
    3: "<TIMESTAMP>",
    4: "<TIMESTAMP>",
    5: "<TIMESTAMP>",
    6: "<TIMESTAMP>",
    7: "<TIMESTAMP>",
}

# One download per seed, in seed order; a single loop instead of eight
# copy-pasted calls, matching the other per-seed cells in this notebook.
for seed, timestamp in prediction_timestamps.items():
    download_blob(
        BUCKET_NAME,
        f"outputs/short_answer/gemini-1.5-pro-002/batch_passk_seed{seed}/prediction-model-{timestamp}/predictions.jsonl",
        BASEDIR / f"outputs/short_answer/gemini-1.5-pro-002/outputs_passk_seed{seed}.jsonl",
    )