In [17]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
from dotenv import load_dotenv
from pathlib import Path

dotenv_path = Path("/home/ubuntu/config.env")

load_dotenv(dotenv_path=dotenv_path)

True

### construct prompt

In [76]:
sytem_tone_context = """You are an expert at writing syntactically correct queries that run on AWS Cloudwatch logs groups.
The query syntax supports different functions, operations required to search and analyze logs."""

context = """<context>
<logs_and_fields>
1. Sytem fields: @message, @timestamp, @ingestionTime, @logStream, @log
2. Amazon VPC flow logs: @timestamp, @logStream, @message, accountId, endTime, interfaceId, logStatus, startTime, version, action, bytes, dstAddr, dstPort, packets, protocol, srcAddr, srcPort
3. Route 53 logs: @timestamp, @logStream, @message, edgeLocation, ednsClientSubnet, hostZoneId, protocol, queryName, queryTimestamp, queryType, resolverIp, responseCode, version
4. Lambda logs: @timestamp, @logStream, @message, @requestId, @duration, @billedDuration, @type, @maxMemoryUsed, @memorySize, @xrayTraceId, @xraySegmentId
</logs_and_fields>

<fields_definition>
1. @message: contains the raw unparsed log event.
2. @timestamp: contains the event timestamp in the log event's timestamp field.
3. @ingestionTime: contains the time when CloudWatch Logs received the log event.
4. @logStream: contains the name of the log stream that the log event was added to. Log streams group logs through the same process that generated them.
5. @log: is a log group identifier in the form of account-id:log-group-name. When querying multiple log groups, this can be useful to identify which log group a particular event belongs to.
</fields_definition>

<commands>
1. display: displays a specific field or fields in query results.
2. fields: displays specific fields in query results and supports functions and operations you can use to modify field values and create new fields to use in your query.
3. filter: filters the query to return only the log events that match one or more conditions.
4. pattern: automatically clusters your log data into patterns. A pattern is shared text structure that recurs among your log fields. CloudWatch Logs Insights provides ways for you to analyze the patterns found in your log events. For more information, see Pattern analysis.
5. diff: compares the log events found in your requested time period with the log events from a previous time period of equal length, so that you can look for trends and find out if certain log events are new.
6. parse: extracts data from a log field to create an extracted field that you can process in your query. parse supports both glob mode using wildcards, and regular expressions.
7. sort: displays the returned log events in ascending (asc) or descending (desc) order.
8. stats: calculate aggregate statistics using values in the log fields.
9. limit: specifies a maximum number of log events that you want your query to return. Useful with sort to return "top 20" or "most recent 20" results.
10. dedup: removes duplicate results based on specific values in fields that you specify.
11. unmask: displays all the content of a log event that has some content masked because of a data protection policy. For more information about data protection in log groups, see Help protect sensitive log data with masking.
</commands>

<operators>
1. arithmetic: `+` | `-` | `*` | `/` | `^` | `%`
2. boolean: `and` | `or` | `not`
3. comparison: `=` | `!=` | `<` | `>` | `<=` | '>=`
4. numeric: `abs` | `ceil` | `floor` | `greatest` | `least` | `log` | `sqrt`
5. datetime: `bin` | `datefloor` | `dateceil` | `fromMillis` | `toMillis`
</operators>

<functions>
1. General: 
    - ispresent(fieldName: LogField)
    - coalesce(fieldName: LogField, ...fieldNames: LogField[])
2. IP address string functions: 
    - isValidIp(fieldName: string)
    - isValidIpV4(fieldName: string)
    - isValidIpV6(fieldName: string)
    - isIpInSubnet(fieldName: string, subnet: string)
    - isIpv4InSubnet(fieldName: string, subnet: string)
    - isIpv6InSubnet(fieldName: string, subnet: string)
3. String funcitons:
    - isempty
    - isblank
    - concat
    - trim
    - strlen
    - toupper
    - tolower
    - substr
    - replace
    - strcontains
</functions>
</context>"""

# TODO: add better & clear rules
rules = """<rules>
1. Make sure that query only uses fields, comparators, operators and functions listed above and no others.
2. Use a pipe character (|) to separate multiple commands.
3. Make sure query is syntactically correct. If there's no query then return "NO_QUERY"
</rules>"""

output_format_instructions = """<output_format_instructions>
When responding use a markdown code snippet with a JSON object formatted in the following schema:

```json
{{
    \"query\": string \ this runs on cloudwatch logs
}}
```

Do not explain. Only respond with the JSON object. 
</output_format_instructions>
"""

system_template = "\n\n".join([
    sytem_tone_context, context, rules, output_format_instructions
])
# template = f"\n\nHuman: {template}\nAssistant: "

# print(system_template)

In [171]:
from langchain_community.vectorstores import FAISS, Chroma
from langchain_community.embeddings import CohereEmbeddings
from langchain.prompts import (
    ChatPromptTemplate, 
    FewShotChatMessagePromptTemplate,
    SemanticSimilarityExampleSelector,
    HumanMessagePromptTemplate,
    AIMessagePromptTemplate,
)
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from data import cwqueries

examples = cwqueries.cw_data
texts = ["\n---\n".join(example.values()) for example in examples]
# print(texts[0])
vectorstore = FAISS.from_texts(
    texts=texts, embedding=CohereEmbeddings(), metadatas=examples
)

example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore, k=3
)

human_template = """Text: {text}"""

# ai_template = '''```json
# {{
#     \"query\": """{query}"""
# }}
# ```'''

ai_template = '''```json
{{
    \"query\": "{query}"
}}
```'''
few_shot_examples_prompt = FewShotChatMessagePromptTemplate(
    input_variables=["text"],
    example_selector=example_selector,
    example_prompt=ChatPromptTemplate.from_messages([
        # HumanMessagePromptTemplate.from_template(human_template),
        # AIMessagePromptTemplate.from_template(ai_template)m
        ChatMessagePromptTemplate.from_template(
            role="H", template=human_template
        ),
        ChatMessagePromptTemplate.from_template(
            role="A", template=ai_template
        )
    ]),
)

print(few_shot_examples_prompt.format(text="latest logs with errors"))

H: Text: Show all logs where the values forloggingType are ERROR
A: ```json
{
    "query": "fields @message
| parse @message "[*] *" as loggingType, loggingMessage
| filter loggingType = "ERROR"
| display loggingMessage"
}
```
H: Text: show the latest 20 log events
A: ```json
{
    "query": "fields @timestamp, @message
| sort @timestamp desc
| limit 20"
}
```
H: Text: show latest 20 logs with range greater than 3000
A: ```json
{
    "query": "fields @timestamp, @message
| filter (range>3000)
| sort @timestamp desc
| limit 20"
}
```


In [172]:
human_template = """Generate query for below question.
Text: {text}"""

In [173]:
# TODO: find a way to wrap few_shot_examples_prompt with <examples></examples> tag
prompt = ChatPromptTemplate.from_messages([
    ("system", system_template),
    few_shot_examples_prompt,
    ("human", human_template)
])

# print(prompt.format(text="latest logs with errors"))

In [174]:
import boto3, json
from langchain_community.chat_models import BedrockChat

llm = BedrockChat(
    model_id="anthropic.claude-3-haiku-20240307-v1:0",
    # model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    client=boto3.client("bedrock-runtime"),
    model_kwargs={"temperature": 0.0, "max_tokens":512}
)

# ! pip install multiline --quiet
# import multiline
# multiline.loads(res, multiline=True)

def _sanitize_output(ai_message: str):
    text = ai_message.content.strip()
    _, after = text.split("```json")
    res = after.split("```")[0]
    res_json = json.loads(res, strict=False)
    return res_json

In [184]:
%%time
question = "show the latest logs that contains ERROR"

response = llm.invoke(prompt.format(text=question))
print(f"actual_response: {response.content}")
print("-"*25)
final_response = _sanitize_output(response)
print(f"after_sanitization: {final_response}")

actual_response: ```json
{
    "query": "fields @timestamp, @message
| filter @message like /ERROR/
| sort @timestamp desc
| limit 20"
}
```
-------------------------
after_sanitization: {'query': 'fields @timestamp, @message\n| filter @message like /ERROR/\n| sort @timestamp desc\n| limit 20'}
CPU times: user 18.3 ms, sys: 3.43 ms, total: 21.8 ms
Wall time: 1.16 s
