In [None]:
import pickle, json, os, time, re
import traceback
import google.generativeai as genai
import cred
import pandas as pd

In [2]:
# Set up Gemini API key
## TODO: Fill in your Gemini API in the ""
genai.configure(api_key=cred.keys['GEMINI_API_KEY'])
model = genai.GenerativeModel(model_name="gemini-1.5-flash")  # No need to prefix with "models/"

# Check if you have set your Gemini API successfully
# You should see "Set Gemini API sucessfully!!" if nothing goes wrong.
try:
    model.generate_content(
      "test",
    )
    print("Set Gemini API sucessfully!!")
except:
    print("There seems to be something wrong with your Gemini API. Please follow our demonstration in the slide to get a correct one.")

Set Gemini API sucessfully!!


In [3]:
class GeminiModel():
    def __init__(self, cache_file="gemini_cache"):
      # Constructor to initialize the OpenAIModel object.
      # cache_file: Name of the file used for caching data.
      self.cache_file = cache_file
      # Load the cache from the file.
      self.cache_dict = self.load_cache()

      ## count token
      #tokens = my_model.model.count_tokens("Hello world")
      #token = tokens.total_token

      # init model
      safety_settings = [
          {
              "category": "HARM_CATEGORY_DANGEROUS",
              "threshold": "BLOCK_NONE",
          },
          {
              "category": "HARM_CATEGORY_HARASSMENT",
              "threshold": "BLOCK_NONE",
          },
          {
              "category": "HARM_CATEGORY_HATE_SPEECH",
              "threshold": "BLOCK_NONE",
          },
          {
              "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
              "threshold": "BLOCK_NONE",
          },
          {
              "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
              "threshold": "BLOCK_NONE",
          },
      ]
      self.model = genai.GenerativeModel('gemini-1.5-flash', safety_settings=safety_settings)

    def save_cache(self):
      # Save the current state of the cache to the cache file.
        with open(self.cache_file, "wb") as f:
            pickle.dump(self.cache_dict, f)

    def load_cache(self, allow_retry=True):
        # Load the cache from the cache file, with retry logic for handling read errors.
        if os.path.exists(self.cache_file):
            while True:
                try:
                    with open(self.cache_file, "rb") as f:
                        cache = pickle.load(f)
                    break
                except Exception:
                    if not allow_retry:
                        assert False
                    print ("Pickle Error: Retry in 5sec...")
                    time.sleep(5)
        else:
            # Initialize the cache if the file does not exist.
            cache = {}
        return cache

    def set_cache_file(self, file_name):
      self.cache_file = file_name
      # Load the cache from the file.
      self.cache_dict = self.load_cache()

    def get_completion(self, content):
        # Get a completion for a given content, optionally using the extraction tool.
        # Checks cache before making a new request to the model.
        if content in self.cache_dict:
            return self.cache_dict[content]
        for _ in range(3):
            try:
              # Make a request to the Gemini model.

              response = self.model.generate_content(
                      content,
                      generation_config=genai.types.GenerationConfig(temperature=1.0),
                      request_options={"timeout": 120}
                  )
              # Store the completion
              completion = response.text
              self.cache_dict[content] = completion
              return completion
            except Exception as e:
                # traceback.print_exc()
                print(e, "\n")
                time.sleep(1)
        return None

    ## need double check
    def is_valid_key(self):
        for _ in range(4):
            try:
                # ref: https://ai.google.dev/tutorials/python_quickstart
                response = self.model.generate_content(
                        "hi there",
                        generation_config=genai.types.GenerationConfig(temperature=1.0)
                    )
                return True
            except Exception as e:
                traceback.print_exc()
                time.sleep(1)
        return False

    def prompt_token_num(self, prompt):
        tokens = self.model.count_tokens(prompt)
        token = tokens.total_tokens

        ## DEBUG
        #print(f"The token num of \'{prompt}\' is {token}")

        return token

    def two_stage_completion(self, question, content):
        # A two-stage completion process: first to get the rationale and then the final answer.
        rationale = self.get_completion(content)
        if not rationale:
            return {
            'prompt': content,
            'rationale': None,
            'answer': None
            }

        ans = self.get_completion(content = f"Q:{question}\nA:{rationale}\nThe answer to the original question is (a number only): ")
        return {
            'prompt': content,
            'rationale': rationale,
            'answer': ans
        }

my_model = GeminiModel()

In [10]:
question = pd.read_csv("HW4 Math Question.csv")

In [7]:
reply = my_model.two_stage_completion('Q:What is the sum of 1 and 2? Ans: {{answer}}', 'Use your instinction')

In [20]:
question.head()

Unnamed: 0,Question,Ground Truth
0,An artist is creating a large mosaic using squ...,228.0
1,A farmer is filling baskets with apples for a ...,15.0
2,A garden has rectangular plots arranged in suc...,132.0
3,A farmer's market sells apples in two types of...,0.45
4,A gardener is planting flowers in a pattern: t...,24.0


In [24]:
ans = []
for i in range(3):
    reply = my_model.get_completion(question['Question'][i])
    matches = re.findall(r'\{(.*?)\}', reply)
    ans.append(matches[0])


In [29]:
for i in range(3):
    print(f"Real answer: {question['Ground Truth'][i]}")
    print(f"Answer from AI: {ans[i]}")
    if question['Ground Truth'][i] == ans[i]:
        print("Correct!")
    else:
        print("Wrong!")
    print("\n")

Real answer: 228.0
Answer from AI: 256
Wrong!


Real answer: 15.0
Answer from AI: 15
Wrong!


Real answer: 132.0
Answer from AI: 132
Wrong!


