In [5]:
from dotenv import load_dotenv, find_dotenv
import re

from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda
from langchain_community.chat_models import BedrockChat

def extract_text_between_markers(text, start_marker='```python', end_marker='```'):
    '''Extract the first fenced code block from an LLM reply.

    Used as the final step of the LangChain pipeline, so ``text`` is normally
    a chat message object whose ``.content`` attribute holds the reply text;
    a plain ``str`` is accepted as well.

    Args:
        text: message object exposing a ``.content`` string, or the string itself.
        start_marker: opening fence; defaults to a python-tagged triple-backtick fence.
        end_marker: closing fence; defaults to a bare triple-backtick fence.

    Returns:
        The text between the first ``start_marker`` and the following
        ``end_marker``, returned unstripped (so it typically begins and ends
        with a newline).

    Raises:
        ValueError: if the reply contains no block delimited by the markers
            (the original implementation failed here with a bare IndexError).
    '''
    content = text if isinstance(text, str) else text.content
    # Non-greedy group + DOTALL: the block may span several lines, and the match
    # stops at the FIRST closing fence rather than the last one in the reply.
    pattern = re.compile(f'{re.escape(start_marker)}(.*?){re.escape(end_marker)}', re.DOTALL)
    match = pattern.search(content)
    if match is None:
        raise ValueError(
            f'No code block delimited by {start_marker!r} ... {end_marker!r} '
            f'found in model output:\n{content}'
        )
    return match.group(1)
 

# define language model
# load_dotenv/find_dotenv are imported above but were never called in this cell.
# Load the nearest .env file so any settings it holds are in os.environ before
# the Bedrock client is created. Presumably these are AWS credentials/region;
# confirm against the project's .env. By default load_dotenv does not override
# variables that are already set, so re-running this cell is harmless.
load_dotenv(find_dotenv())

model_id = 'anthropic.claude-3-sonnet-20240229-v1:0'
# Alternative (cheaper/faster) model: 'anthropic.claude-3-haiku-20240307-v1:0'
# temperature=0 keeps code generation as repeatable as the model allows.
llm = BedrockChat(model_id=model_id, model_kwargs={'temperature': 0})


# System prompt for the code-repair step.
# It hard-codes the overall task, the code that already ran successfully, and
# the block that failed. The error text itself is NOT embedded here: it arrives
# at invoke time as a HumanMessage through the "messages" placeholder of
# generate_prompt_template.
# The final instruction asks the model to wrap its answer in a ```python fence.
# That fence is what extract_text_between_markers looks for.
# NOTE(review): ChatPromptTemplate treats this text as a format template by
# default, so any literal "{" or "}" added here must be doubled.
# NOTE(review): the prompt gives the model no signature for statcast_pitcher.
# The recorded response below still passes duffy_id positionally, so verify
# that call against pybaseball's documentation before executing it.
CONVERT_SYSTEM_PROMPT = '''
You are a world-class Python programmer and an expert on baseball, with a specialization in data analysis using the pybaseball Python library. 
Your goal is to update a code block in order to resolve an error.

This code block is one step toward accomplishing this task:
###
Plot the cumulative sum of strikeouts thrown by Danny Duffy in the 2018 season. 
###

Here is the code that has been executed successfully so far:
```python

# Import necessary libraries
from pybaseball import playerid_lookup, statcast_pitcher
import matplotlib.pyplot as plt
# 1. Get the 'key_mlbam' ID for Danny Duffy
duffy_id = playerid_lookup('duffy', 'danny')['key_mlbam'].values[0]
```


Here is the next code block that threw an error: 
```python
# 2. Get the pitch-level data for Danny Duffy in 2018
duffy_2018 = statcast_pitcher(duffy_id, 2018, ['date', 'strikeouts']) 
```

Review the error message and rewrite the code block that threw an error to resolve the issue. 

Return all python code between three tick marks like this:

```python
python code goes here
```

'''

# Prompt layout: the fixed system instructions first, then whatever chat
# messages the caller supplies under the "messages" key at invoke time.
generate_prompt_template = ChatPromptTemplate.from_messages(
    [("system", CONVERT_SYSTEM_PROMPT), MessagesPlaceholder(variable_name="messages")]
)

In [6]:
# Exception repr handed to the model verbatim. Presumably it was raised by the
# statcast_pitcher call quoted in CONVERT_SYSTEM_PROMPT.
error = (
    "TypeError("
    "'strptime() argument 1 must be str, not numpy.int64')"
)

In [7]:
# Chain: prompt -> Bedrock chat model -> first ```python block of the reply
generate_chain = (
    generate_prompt_template
    | llm
    | RunnableLambda(extract_text_between_markers)
)

error_prompt = HumanMessage(content=f'Here is the error message:\n\n{error}')
code_solution = generate_chain.invoke({"messages": [error_prompt]})

In [8]:
# Show the model's proposed replacement code with real line breaks.
print(f'{code_solution}')


# 2. Get the pitch-level data for Danny Duffy in 2018
duffy_2018 = statcast_pitcher(duffy_id, start_dt=str(2018), end_dt=str(2018), data_type=['date', 'strikeouts'])

