Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


In [None]:
# @title ##Install Required Packages and Restart Colab

# IPython shell escape (not plain Python): installs/upgrades the packages this notebook imports.
!pip install google-cloud-aiplatform pybaseball gradio typing-extensions==4.5 pydantic-core==0.42 matplotlib --upgrade

# Automatically restart kernel after installs so that your environment can access the new packages
import IPython

# Grab the running IPython application and ask its kernel to shut down.
# NOTE(review): the True argument is presumably the "restart" flag — confirm against the ipykernel docs.
app = IPython.Application.instance()
app.kernel.do_shutdown(True)

In [None]:
# @title ##Authenticate to Colab with GCP Credentials

# Only available inside Google Colab; run this before the Vertex AI cell so the
# SDK calls made later in the notebook have credentials to use.
from google.colab import auth
auth.authenticate_user()

In [None]:
# @title ##Setup Vertex AI with Gemini Pro
import vertexai
from vertexai.preview import generative_models
from vertexai.preview.generative_models import GenerativeModel, Part

# Colab form fields: project, model and region used by every Vertex AI call in this notebook.
PROJECT_ID = "cloud-llm-preview4"  # @param {type:"string"}
model_name = "gemini-pro" # @param ["gemini-pro"]
LOCATION = "us-central1"  # @param ["us-central1", "us-east1", "us-east4", "us-east5", "us-west1", "us-west2", "us-west3", "us-west4", "us-south1"]
# Point the Vertex AI SDK at the selected project and region.
vertexai.init(project=PROJECT_ID, location=LOCATION)

# Sampling settings; generate() captures these as its keyword defaults when it is defined,
# so changing them requires re-running this cell.
max_output_tokens_val = 8192  # @param {type:"integer"}
temperature_val = 0.8 # @param {type:"number"}
top_k_val = 40 # @param {type:"integer"}
top_p_val = 1 # @param {type:"number"}

# Safety config: every listed harm category is mapped to the same BLOCK_NONE threshold.
safety_config = {
    category: generative_models.HarmBlockThreshold.BLOCK_NONE
    for category in (
        generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
        generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        generative_models.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
        generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
    )
}

def generate(prompt,max_output_tokens=max_output_tokens_val,temperature=temperature_val,top_p=top_p_val,top_k=top_k_val,safety_config=safety_config):
  """Send `prompt` to the model named by `model_name` and return its full text answer.

  The request is made in streaming mode; the text of every streamed chunk is
  concatenated, in arrival order, into one string.

  Args:
    prompt: Content passed straight through to `generate_content`.
    max_output_tokens, temperature, top_p, top_k: Forwarded in the
      generation config under keys of the same name.
    safety_config: Forwarded as `safety_settings`.

  Returns:
    The concatenation of `.text` from each streamed response chunk.
  """
  gemini = GenerativeModel(model_name)
  sampling = {
      "max_output_tokens": max_output_tokens,
      "temperature": temperature,
      "top_p": top_p,
      "top_k": top_k,
  }
  chunk_stream = gemini.generate_content(
      prompt,
      generation_config=sampling,
      stream=True,
      safety_settings=safety_config,
  )
  # Chunks are consumed lazily, one at a time, as the join pulls from the stream.
  return "".join(chunk.text for chunk in chunk_stream)

In [None]:
# @title ## Basics: pybaseball Spray Chart

from pybaseball import statcast_batter, spraychart, cache, playerid_lookup

cache.enable()
season = 2023 # @param {type:"integer"}
player_name = "Corey Seager" # @param {type:"string"}
home_team = "BOS" # @param ["HOU", "LAA", "ATL", "BAL", "BOS", "LAD", "OAK", "MIA", "MIL", "MIN", "SFG", "TBR", "CIN", "CLV", "CSW", "STL", "SEA", "TEX", "WSN", "NYY", "NYM", "PHI", "SDP", "CHC", "PIT", "TOR", "DET", "KCR", "ARI"]
stadium ="red_sox" # @param ["angels", "astros", "athletics", "blue_jays", "braves", "brewers", "cardinals", "cubs", "diamondbacks", "dodgers", "generic", "giants", "indians", "mariners", "marlins", "mets", "nationals", "orioles", "padres", "phillies", "pirates", "rangers", "rays", "red_sox", "reds", "rockies", "royals", "tigers", "twins", "white_sox", "yankees"]

# The form value must be exactly "First Last": the unpacking fails on any other word count.
first, last = player_name.split(" ")

# Look the player up (fuzzy matching on) and adopt the first returned row's
# spelling of the name together with its key_mlbam id.
player_id_df = playerid_lookup(last, first, fuzzy=True)
first_name = player_id_df['name_first'][0]
last_name = player_id_df['name_last'][0]
player_name = " ".join([first_name, last_name])
player_id = player_id_df['key_mlbam'][0]

title = " ".join([player_name, str(season), "Spray Chart @", stadium])

# Pull the batter's data for April 1 - October 1 of the season, keep only the
# rows whose home_team matches the selection, and draw them on the stadium outline.
data = statcast_batter(f"{season}-04-01", f"{season}-10-01", player_id)
sub_data = data.loc[data['home_team'] == home_team]
spraychart(sub_data, stadium, title=title)

In [None]:
# @title ##Test Gemini Pro: Generate Code
code_prompt = "What is the python code to display a bar chart?" # @param {type:"string"}
# Keep the raw model answer in `code_output`; the following cell passes it to exec().
code_output = generate(code_prompt)
print(code_output)

In [None]:
# @title ##Run Gemini Pro Generated Python
import re

# SECURITY: exec() runs whatever text the model returned with the notebook's full
# privileges. Read the printed `code_output` from the previous cell before running this.
#
# Strip only the Markdown fence markers: "```" with an optional "python" language tag
# (any letter case). A blanket str.replace("python", "") would also delete the word
# from identifiers, strings and comments inside the generated code and break it.
exec(re.sub(r"```(?:python)?", "", code_output, flags=re.IGNORECASE))

In [None]:
# @title ##Example: Routing prompt based upon user input / Data Structures

# Sample natural-language requests. Further down in this cell each one is substituted
# into the {user_query} slot of the routing prompt and sent to the model.
#user_query = "Show me the spray chart for Mark Mcgwire in 1998 at STL" # @param {type:"string"}"
user_query = "Show me the spray chart for adolis garcia and the spray chart Jose Altuve for 2022 at New York stadium" # @param {type:"string"}"
user_query_2 = "Show me the strike zone for Jose Altuve during 2023 with shohei ohtani pitching" # @param {type:"string"}"
user_query_3 = "Show me the strike zone for Corey Seager during 2023 at TEX stadium with shohei ohtani pitching" # @param {type:"string"}"
user_query_4 = "Show me both the spray chart and the strike zone for Corey Seager during 2023 at TEX stadium with Shohei Ohtani pitching" # @param {type:"string"}"
user_query_5 = "Show me Corey Seager's batting stats for the 2023 season" # @param {type:"string"}"


# Routing prompt, laid out with the CO-STAR prompt framework. It asks the model to
# translate a free-text request into a JSON array of function calls.
#
# - The {user_query} slot is filled with str.replace (not str.format), so the literal
#   braces in the JSON examples need no escaping.
# - The stadium codes in OBJECTIVE step 3 mirror the home_team dropdown of the
#   spray-chart cell, including HOU, so Houston requests are not mapped to "Unknown".
# - NOTE(review): neither list has a code for the Rockies although the team list
#   names them — confirm which home_team value the data uses and add it.
# - This is a module-level name; functions that only read it need no `global` statement.
prompt_template = """&CONTEXT&
Analyze USER_QUERY for which functions need to be called. Each function should have it\'s parameters filled out and provided as the output in JSON format.

&OBJECTIVE&
1 - Using the USER_QUERY, determine if one or more functions need to be called
2 - For each function, attempt to extract the parameters keeping in mind the DESCRIPTION.
3 - If a stadium name is provided, convert it to one of the following options: HOU, LAA, ATL, BAL, BOS, LAD, OAK, MIA, MIL, MIN, SFG, TBR, CIN, CLV, CSW, STL, SEA, TEX, WSN, NYY, NYM, PHI, SDP, CHC, PIT, TOR, DET, KCR, ARI, or Unknown.
4 - If a team name is provided, convert it to one of the following options:  Options are Diamondbacks, Braves, Orioles, Red Sox, Cubs, White Sox, Reds, Guardians, Rockies, Tigers, Astros, Royals, Angels, Dodgers, Marlins, Brewers, Twins, Mets, Yankees, Athletics, Phillies, Pirates, Padres, Giants, Mariners, Cardinals, Rays, Rangers, Blue Jays, Nationals, or Unknown.
5 - If no data can be matched to any function, output INVALID

&STYLE&
Always use a programing style.

&TONE&
Programming.

&AUDIENCE&
The output will be passed to a python program so it must be syntactically correct and valid.

&RESPONSE&
Output only in JSON format using the provided FUNCTIONS.
If a parameter is unavailable, only use the string "Unknown".
Always use the EXAMPLE to output one or many function calls as a JSON array:

#EXAMPLE#
[
    {
        "function1": {
            "parameters": {
                "name1": "value1",
                "name2": "value2"
            }
        }
    },
    {
        "function2": {
            "parameters": {
                "name1": "value1"
            }
        }
    },
    {
        "function3": {
            "parameters": {
                "name1": "value1",
                "name2": "value2"
            }
        }
    }
]

#FUNCTIONS#
[
    {
        "get_single_player_data": {
            "parameters": {
                "player_name": "Required. A baseball players name.",
                "team_name": "Required. If not explicity provided, respond with Unknown.",
                "pitcher_name": "Required. The pitcher's name, determined by using the word pitched by, pitching, or pitcher.",
                "stadium": "Required. The short name for the baseball stadium. If not explicitly provided, respond with Unknown.",
                "season": "Required. The four digit year (e.g. 1999). If not explicitly provided, respond with Unknown.",
                "chart_type": "Required. Options are spray_chart, plot_strike_zone, plot_bb_profile, plot_baseball_profile, plot_teams, batting_stats, pitching_stats, or Unknown."
            }
        }
    },
    {
        "comparison_request": {
            "parameters": {
                "comparison_type": "Only include if USER_QUERY asks for a comparison request. Options are bar_chart, line_chart, table, or Unknown."
            }
        }
    }
]

&USER_QUERY&
{user_query}

#Remember to use the CONTEXT, follow the OBJECTIVE, and output in the EXAMPLE format.#
"""

# Fill the template's {user_query} slot once per sample request.
prompt = prompt_template.replace("{user_query}",user_query)
prompt_2 = prompt_template.replace("{user_query}",user_query_2)
prompt_3 = prompt_template.replace("{user_query}",user_query_3)
prompt_4 = prompt_template.replace("{user_query}",user_query_4)
prompt_5 = prompt_template.replace("{user_query}",user_query_5)

# Send the prompts to the model one after another and print each raw answer.
# When the loop ends, `output` holds the answer to prompt 5.
for prompt_number, routing_prompt in enumerate((prompt, prompt_2, prompt_3, prompt_4, prompt_5), start=1):
  output = generate(routing_prompt)
  print("Prompt " + str(prompt_number) + " output:")
  print(output)



In [None]:
# @title ##Example: Asking for a Comparison

# A request that names two players, each with a different chart type.
user_query = "Show me the spray chart for Shohei Ohtani on the Angels and the strike zone for Jose Altuve on the Astros for 2021 at New York stadium" # @param {type:"string"}"

prompt = prompt_template.replace("{user_query}",user_query)

# The raw answer kept in `output` is what the next cell parses with json.loads.
output = generate(prompt)
print(output)

In [None]:
# @title ##Defining Python Functions / RAGs

import json
import re

# The model may wrap its JSON answer in a Markdown code fence. Remove the fence markers
# together with an optional "json" language tag in any letter case; matching only the
# upper-case "```JSON" spelling would leave a stray "json" in front of the payload
# whenever the model writes the tag in lower case, and json.loads would then fail.
functions = json.loads(re.sub(r"```(?:json)?", "", output, flags=re.IGNORECASE))
print(functions)

from pybaseball import statcast_batter, spraychart, playerid_lookup, statcast_pitcher, statcast, cache
from pybaseball.plotting import plot_strike_zone, plot_bb_profile

cache.enable()

# Module-level alias of the parsed calls.
json_object = functions
def _split_full_name(full_name):
  """Split a "First Last" style name into a (first, last) pair.

  Everything after the first space is treated as the last name, so a name with
  more than two words does not break tuple unpacking. Returns None when there
  is no second part at all (for example the "Unknown" placeholder).
  """
  first, _, last = full_name.strip().partition(" ")
  last = last.strip()
  if not first or not last:
    return None
  return first, last


def process_json_output(json_object):
  """Run the function calls routed by the model and yield the results one by one.

  Args:
    json_object: The parsed model answer: an iterable of single-key dicts
      shaped like ``{"<function name>": {"parameters": {...}}}``.

  Yields:
    For "get_single_player_data" entries: status strings, markdown tables
    (as strings) and whatever the pybaseball plotting helper for the requested
    chart type returns, in the order they are produced. "comparison_request"
    entries only print their comparison type and yield nothing. Entries with
    any other function name are ignored.
  """
  for function in json_object:
    function_name = list(function.keys())[0]

    if function_name == "get_single_player_data":
      params = function[function_name]["parameters"]

      # NOTE(review): encoding and then decoding with unicode_escape looks like a
      # no-op round trip — confirm before removing it.
      player_name = bytes(params["player_name"], 'unicode_escape').decode('unicode_escape')
      batter_parts = _split_full_name(player_name)
      if batter_parts is None:
        # The prompt tells the model to answer "Unknown" for missing values; a
        # single-word name cannot be looked up, so report it and skip this call.
        yield "Could not find a first and last name in '" + player_name + "', skipping this request."
        continue
      first, last = batter_parts
      player_team_name = params["team_name"]

      # Optional pitcher filter. None means "no pitcher filter".
      pitcher_id = None
      pitcher_name = "Unknown"
      requested_pitcher = params["pitcher_name"]
      if requested_pitcher not in (None, "", "Unknown", "None"):
        pitcher_parts = _split_full_name(
            bytes(requested_pitcher, 'unicode_escape').decode('unicode_escape'))
        if pitcher_parts is None:
          yield "Could not find a first and last name in '" + requested_pitcher + "', ignoring the pitcher filter."
        else:
          pitcher_first, pitcher_last = pitcher_parts
          pitcher_id_df = playerid_lookup(pitcher_last, pitcher_first, fuzzy=True)
          # Inspect the pitcher's own lookup result here: the batter lookup
          # has not run yet at this point.
          if len(pitcher_id_df) > 1:
            yield "No exact match to '" + requested_pitcher + "' found, assuming first player in this list:"
            yield pitcher_id_df.to_markdown(index=False)
          pitcher_id = pitcher_id_df['key_mlbam'][0]
          print("pitcher_id: " + str(pitcher_id))
          pitcher_name = pitcher_id_df['name_first'][0] + " " + pitcher_id_df['name_last'][0]

      # Defaults used when the request did not name a stadium / team.
      stadium = params["stadium"]
      if stadium == "Unknown":
        stadium = "TEX"
      if player_team_name == "Unknown":
        player_team_name = "Rangers"

      season = params["season"]
      season_start = str(season) + "-04-01"
      season_end = str(season) + "-10-01"

      player_id_df = playerid_lookup(last, first, fuzzy=True)
      if len(player_id_df) > 1:
        yield "No exact match to '" + player_name + "' found, assuming first player in this list:"
        yield player_id_df.to_markdown(index=False)
      print(player_id_df.to_markdown(index=False))
      # From here on use the spelling of the name found by the lookup.
      player_name = player_id_df['name_first'][0] + " " + player_id_df['name_last'][0]
      player_id = player_id_df['key_mlbam'][0]
      print("player_id: " + str(player_id))
      print("stadium: " + stadium)

      chart_type = params["chart_type"]
      print("chart_type: " + chart_type)

      if chart_type == "spray_chart":
        yield "Gathering data: " + player_name + " Spray Chart @ " + stadium + " for " + str(season)
        data = statcast_batter(season_start, season_end, player_id)
        if pitcher_id is None:
          sub_data = data[data['home_team'] == stadium]
          title = player_name + " Spray Chart @ " + stadium + " for " + str(season)
        else:
          sub_data = data[(data['home_team'] == stadium) & (data["pitcher"] == pitcher_id)]
          title = player_name + " Spray Chart against " + pitcher_name + " @ " + stadium + " for " + str(season)
        yield spraychart(sub_data, team_stadium=player_team_name, title=title, colorby='events')

      elif chart_type == "plot_strike_zone":
        yield "Gathering data: " + player_name + " Strikezone @ " + stadium + " for " + str(season)
        all_data = statcast(season_start, season_end)
        rows = (all_data["batter"] == player_id) & (all_data["home_team"] == stadium)
        if pitcher_id is None:
          title = player_name + " Strikezone @ " + stadium + " for " + str(season)
        else:
          rows = rows & (all_data["pitcher"] == pitcher_id)
          title = player_name + " Strikezone against " + pitcher_name + " @ " + stadium + " for " + str(season)
        yield plot_strike_zone(all_data.loc[rows], title=title, colorby='events')

      elif chart_type in ("plot_bb_profile", "plot_baseball_profile"):
        yield "Gathering data: " + player_name + " Baseball Profile @ " + stadium + " for " + str(season)
        all_data = statcast(season_start, season_end)
        rows = (all_data["batter"] == player_id) & (all_data["home_team"] == stadium)
        if pitcher_id is not None:
          rows = rows & (all_data["pitcher"] == pitcher_id)
        yield plot_bb_profile(all_data.loc[rows])

      # Each stats table is yielded exactly once, right after its status line.
      elif chart_type == "batting_stats":
        yield "Gathering data: " + player_name + " Baseball Stats for " + str(season)
        yield statcast_batter(season_start, season_end, player_id).to_markdown(index=False)

      elif chart_type == "pitching_stats":
        yield "Gathering data: " + player_name + " Pitching Stats for " + str(season)
        yield statcast_pitcher(season_start, season_end, player_id).to_markdown(index=False)

    elif function_name == "comparison_request":
      comparison_type = function[function_name]["parameters"]["comparison_type"]
      print(comparison_type)

# Drain the generator and report the type of every item it yields.
for value in process_json_output(functions):
  print(f"{type(value)}")

In [None]:
# @title ## Get Avatar for Chatbot

# IPython shell escape (not plain Python): installs/upgrades CairoSVG, imported below as cairosvg.
!pip install --upgrade CairoSVG

from PIL import Image
import cairosvg

# Convert the SVG at the given URL to avatar.png in the working directory;
# the Gradio cell passes that file to the chatbot as an avatar image.
cairosvg.svg2png(url="https://baseballsavant.mlb.com/site-core/images/savant-logo.svg", write_to="avatar.png")
# Show the converted file inline as a quick visual check.
display(Image.open("avatar.png"))

In [None]:
# @title ## Setup Temp dir for images
# Start from an empty directory: delete any previous one ("|| true" ignores a failing rm), then recreate it.
!rm -rf /content/chatbot_images || true
!mkdir -p /content/chatbot_images
# Keep the trailing slash: bot() builds image paths by plain string concatenation with this prefix.
images_dir = "/content/chatbot_images/"

In [None]:
# @title ## Initialize & Run Gradio

import gradio as gr
import pandas as pd
import time, os, uuid

def print_like_dislike(x: gr.LikeData):
    """Print the index, value and liked flag carried by a like/dislike event."""
    event_details = (x.index, x.value, x.liked)
    print(*event_details)


def add_text(history, text):
    """Append the submitted text to the chat history and reset the input box.

    Returns a pair: a new history list ending in ``(text, None)`` (the None is
    the slot for the bot's reply), and a Textbox with an empty value and
    ``interactive=False``.
    """
    extended_history = history + [(text, None)]
    cleared_box = gr.Textbox(value="", interactive=False)
    return extended_history, cleared_box


def bot(history):
    """Answer the newest user message and stream the growing chat history.

    The last history entry's first element is taken as the user query. It is
    substituted into `prompt_template`, sent to the model with `generate`, and
    the JSON answer is executed by `process_json_output`. After every item that
    generator produces, the updated history is yielded so the UI can refresh.

    Args:
        history: List of (user_message, bot_message) pairs.

    Yields:
        The history list after each processed item. Strings and DataFrames are
        appended as text; objects exposing `get_figure()` are saved as PNG files
        under `images_dir` and appended as a one-element file tuple. Any other
        item leaves the history unchanged for that step.
    """
    import re  # local import: only needed for stripping Markdown fences here

    # to add a bot msg response - it goes into the second tuple (user is first) in a history element
    # media can be added by providing a filename
    user_query = history[-1][0]
    print(user_query)
    prompt = prompt_template.replace("{user_query}",user_query)
    print("---")
    print(prompt)
    print("---")
    output = generate(prompt)
    print(output)

    # Remove Markdown fence markers plus an optional "json" tag in any letter case.
    cleaned_output = re.sub(r"```(?:json)?", "", output, flags=re.IGNORECASE)
    try:
        functions = json.loads(cleaned_output)
    except json.JSONDecodeError:
        # The prompt tells the model to answer INVALID when nothing matches, and
        # that is not JSON. Tell the user instead of crashing the event handler.
        history = history + [(None, "Sorry, I could not turn that request into a chart or stats lookup.")]
        yield history
        return history

    for value in process_json_output(functions):
        print(str(type(value)))
        if isinstance(value, str):
            history = history + [(None, value)]
        elif isinstance(value, pd.DataFrame):
            history = history + [(None, str(value))]
        elif hasattr(value, "get_figure"):
            # Duck-typed instead of comparing against one hard-coded class path,
            # so any Axes-like object that can hand back its figure is rendered.
            value_filename = images_dir+uuid.uuid4().hex+".png"
            value.get_figure().savefig(value_filename)
            history = history + [(None, (value_filename,))]
        yield history
    return history

# Extra CSS handed to gr.Blocks. The only explicit use of the .container class in
# this cell is the commented-out elem_classes argument on the chatbot below.
css = """
.container {
    height: 160vh;
}
"""

# Layout: one column holding the chat window, with a single-line text input in a row under it.
with gr.Blocks(css=css) as demo:
  with gr.Column():
    chatbot = gr.Chatbot(
        [],
        #elem_classes=["container"],
        height=500,
        elem_id="chatbot",
        bubble_full_width=False,
        # os.path.dirname(".") is the empty string, so this resolves to "avatar.png"
        # relative to the working directory; the first avatar slot is left empty.
        avatar_images=(None, (os.path.join(os.path.dirname("."), "avatar.png"))),
        )

    with gr.Row():
        txt = gr.Textbox(
            #scale=1,
            show_label=False,
            placeholder="Enter text and press enter, or upload an image",
            container=False,
        )

  # Event chain on submit:
  #   1. add_text appends the message to the chat and returns a non-interactive, empty textbox;
  #   2. bot streams its replies into the chatbot;
  #   3. the textbox is made interactive again.
  txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
      bot, chatbot, chatbot, api_name="bot_response"
  )
  txt_msg.then(lambda:
               gr.Textbox(interactive=True), None, [txt], queue=False
               )

  # Like/dislike clicks are only printed by print_like_dislike.
  chatbot.like(print_like_dislike, None, None)


# Enable the queue, then launch with debug and share both turned on.
demo.queue()
demo.launch(debug=True, share=True)