In [None]:
# Install KerasNLP 0.12.1 from a wheel attached as a Kaggle dataset
# (presumably because the notebook runs without internet access).
# --no-deps installs only this wheel, without resolving its dependencies.
!pip install -q /kaggle/input/keras-lib-dataset/keras_nlp-0.12.1-py3-none-any.whl --no-deps

In [None]:
import os
# The backend must be selected before `import keras` below for it to take effect.
os.environ["KERAS_BACKEND"] = "jax" # you can also use tensorflow or torch
# NOTE(review): per JAX's GPU memory docs this is the fraction of GPU memory JAX
# preallocates at startup; "1.00" lets it claim all of it up front.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00" # avoid memory fragmentation on JAX backend.

import keras
import keras_nlp

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
# Registers `progress_apply` on pandas objects (used when building prompts).
tqdm.pandas() # progress bar for pandas

# NOTE(review): the plotly imports below are not used by any cell in this notebook.
import plotly.graph_objs as go
import plotly.express as px
from IPython.display import display, Markdown

In [None]:
class CFG:
    """Run-wide settings for this notebook, read as plain class attributes."""

    # Reproducibility
    seed: int = 42

    # Locations: competition data directory and KerasNLP preset of the pretrained Gemma
    dataset_path: str = "/kaggle/input/ai-mathematical-olympiad-prize"
    preset: str = "gemma_1.1_instruct_2b_en"

    # Training
    # Token budget per training example (applied to the preprocessor before fit).
    # NOTE(review): the prompt preamble is long, so solutions that overflow this
    # budget may lose their trailing "Answer:" section — verify before relying on it.
    sequence_length: int = 512
    batch_size: int = 1  # examples per training step
    epochs: int = 1  # passes over the fine-tuning data

In [None]:
# Seed the RNGs Keras manages (per the Keras docs: Python, NumPy and the
# backend framework) from a single value so runs are repeatable.
keras.utils.set_random_seed(CFG.seed)

In [None]:
# Pool the train and test splits of the math-qsa dataset into one fine-tuning
# frame (later cells use its `problem`, `solution` and `answer` columns).
# ignore_index=True gives the result a fresh 0..n-1 index; without it both
# frames' labels survive and collide (two rows labelled 0, 1, ...), which makes
# label-based lookups and alignment ambiguous.
df1 = pd.read_csv("/kaggle/input/math-qsa-dataset/train.csv")
df2 = pd.read_csv("/kaggle/input/math-qsa-dataset/test.csv")
df = pd.concat([df1, df2], axis=0, ignore_index=True)
df.head(2)

In [None]:
def is_integer(text):
    """Return True if ``text`` holds a non-negative integer value, else False.

    Accepts ints, integer-valued floats (e.g. ``5.0``) and strings that
    ``int()`` parses (e.g. ``"42"``, ``" 7 "``). Rejects negative values,
    non-integral numbers such as ``3.5`` (which ``int()`` would silently
    truncate to 3), NaN/inf, ``None`` and anything else ``int()`` cannot
    convert.

    Args:
        text: Candidate answer, typically a string or a number.

    Returns:
        bool: whether the value is a non-negative integer.
    """
    try:
        value = int(text)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / unsupported types; ValueError: unparseable strings
        # or NaN; OverflowError: +/-inf.
        return False
    # int() truncates numbers (int(3.5) == 3), so for non-string inputs require
    # that the conversion was exact.
    if not isinstance(text, (str, bytes)) and value != text:
        return False
    return value >= 0
    
# Keep only rows whose answer is a non-negative integer (the only kind of answer
# get_answer can parse); the boolean flag column is kept for inspection.
df["is_integer"] = df["answer"].map(is_integer)
df = df.loc[df["is_integer"]].reset_index(drop=True)
df.head(2)

In [None]:
# Prompt shared by fine-tuning (solution + "Answer:" filled in) and inference
# (solution left empty). Step 3 asks for a non-negative integer, not an
# algebraic expression. That keeps it consistent with step 2 and with
# get_answer(), which can only parse integers.
template = """Role:\nYou are an advanced AI system with exceptional mathematical reasoning and problem-solving capabilities, specifically designed to solve tricky math problems (whose answer is a non-negative integer) written in LaTeX format from the AI Mathematical Olympiad (AIMO) competition. Your task is to accurately analyze and solve intricate mathematical problems, demonstrating a deep understanding of mathematical concepts and a strong ability to apply logical reasoning strategies.\n\nInstruction:
1. Carefully read and comprehend the problem statement provided in the "Problem" section.
2. In the "Solution" section, provide a solution to the problem with a detailed explanation of your logical reasoning process. Keep in mind that the answer must be a non-negative integer.
3. At the end, create an "Answer" section where you will state only the final non-negative integer answer, without any additional text or narrative.\n\nProblem:\n{problem}\n\nSolution:\n{solution}"""

In [None]:
def _training_prompt(row):
    """Render one fine-tuning example: the problem, its worked solution and the final answer."""
    completed_solution = f"{row.solution}\n\nAnswer:\n{row.answer}"
    return template.format(problem=row.problem, solution=completed_solution)


# One formatted training string per row (progress_apply shows a tqdm bar).
df["prompt"] = df.progress_apply(_training_prompt, axis=1)
data = df["prompt"].tolist()

In [None]:
def colorize_text(text):
    """Return ``text`` with each prompt section header highlighted for Markdown.

    Every occurrence of ``Role:``, ``Instruction:``, ``Problem:``,
    ``Solution:`` and ``Answer:`` becomes a bold, coloured HTML ``<font>``
    tag preceded by a blank line. Occurrences inside the problem or solution
    body are styled as well, because every match is replaced.
    """
    palette = {
        "Role": "blue",
        "Instruction": "yellow",
        "Problem": "red",
        "Solution": "cyan",
        "Answer": "green",
    }
    for header, color in palette.items():
        marker = f"{header}:"
        text = text.replace(marker, f"\n\n**<font color='{color}'>{marker}</font>**")
    return text

In [None]:
# Preview training prompt #12 (a fixed index, not a random draw) with its
# Role / Instruction / Problem / Solution / Answer headers colour-coded.
sample = colorize_text(data[12])
display(Markdown(sample))

In [None]:
# Second preview: training prompt #32 (fixed index), rendered with the same
# colour-coded section headers.
sample = colorize_text(data[32])
display(Markdown(sample))

In [None]:
# Instantiate the pretrained Gemma causal LM (backbone plus its preprocessor,
# used below) from the KerasNLP preset named in CFG.preset, and print its layers.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(CFG.preset)
gemma_lm.summary()

In [None]:
# Run the first two training prompts through the model's preprocessor to inspect
# the `x` dict and the `y` / `sample_weight` outputs it returns. At this point the
# preprocessor still uses its default sequence length (CFG.sequence_length is
# applied later, just before training).
x, y, sample_weight = gemma_lm.preprocessor(data[0:2])

In [None]:
# Print the name and shape of every preprocessed input tensor, one per line.
for k, v in x.items():
    print(f"{k} : {v.shape}")

In [None]:
# Before fine-tuning: let the base model solve problem #12 (the same row as the
# first training preview above), starting from an empty "Solution:" section, and
# render the colour-coded result.
row = df.iloc[12]
prompt = template.format(problem=row.problem, solution="")
output = colorize_text(gemma_lm.generate(prompt, max_length=1024))
display(Markdown(output))

In [None]:
# Before fine-tuning, second probe: problem #32, same procedure as for #12.
row = df.iloc[32]
prompt = template.format(solution="", problem=row.problem)
generated = gemma_lm.generate(prompt, max_length=1024)
output = colorize_text(generated)
display(Markdown(output))

In [None]:
# Enable LoRA for the model and set the LoRA rank to 4.
# (The rank is the inner dimension of the low-rank update matrices; the summary
# printed below should reflect the change in trainable parameters.)
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

In [None]:
# Cap every tokenized training example at CFG.sequence_length (512) tokens to
# bound memory use.
gemma_lm.preprocessor.sequence_length = CFG.sequence_length

# The loss is computed from raw logits (from_logits=True). Accuracy is registered
# as a weighted metric so it honours the sample weights the preprocessor emits.
gemma_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate=2e-5),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train on the formatted prompts (LoRA was enabled in the previous cell).
gemma_lm.fit(data, batch_size=CFG.batch_size, epochs=CFG.epochs)

In [None]:
# After fine-tuning: re-run problem #12 with the same empty-solution prompt so
# the output can be compared with the pre-training baseline.
row = df.iloc[12]
prompt = template.format(problem=row.problem, solution="")
output = colorize_text(gemma_lm.generate(prompt, max_length=1024))
display(Markdown(output))

In [None]:
# After fine-tuning, second probe: problem #32, for comparison with its baseline.
row = df.iloc[32]
prompt = template.format(solution="", problem=row.problem)
generated = gemma_lm.generate(prompt, max_length=1024)
output = colorize_text(generated)
display(Markdown(output))

In [None]:
import re

# Extract answer from model response
def get_answer(text):
    """Extract the final integer answer from a model response.

    Reads the first non-empty line after the first ``Answer:`` marker, drops
    thousands separators, and takes the first number on that line. Only that
    line is inspected, so text the model keeps generating afterwards (another
    problem, stray tokens, ...) no longer spoils the parse. The value is
    reduced ``% 1000`` (presumably the competition's 0-999 answer range).

    Args:
        text: Output of ``gemma_lm.generate`` (expected to be a str).

    Returns:
        int: the answer in ``[0, 999]``; ``0`` when ``text`` is not a string,
        has no ``Answer:`` line, or the number found is negative or
        non-integral (e.g. ``12.5``).
    """
    if not isinstance(text, str):
        return 0
    # `.` does not match newlines: `\s*` skips blank space after the marker and
    # `(.+)` captures just the first non-empty line, not the rest of the text.
    match = re.search(r"Answer:\s*(.+)", text)
    if match is None:
        return 0
    answer_line = match.group(1).replace(",", "")
    number = re.search(r"-?\d+(?:\.\d+)?", answer_line)
    if number is None:
        return 0
    whole, _, fraction = number.group(0).partition(".")
    # Negative or genuinely fractional values are not valid answers; "7.0" is fine.
    if whole.startswith("-") or fraction.strip("0"):
        return 0
    return int(whole) % 1000
    
    
def infer(df, max_length=1024):
    """Generate and parse an answer for every problem in ``df``.

    Relies on the module-level ``gemma_lm``, ``template`` and ``get_answer``.

    Args:
        df: DataFrame with ``id`` and ``problem`` columns. If it also has an
            ``answer`` column, the ground truth is appended to each record.
        max_length: Forwarded to ``gemma_lm.generate``. Defaults to 1024, the
            value hard-coded previously and used by the exploratory cells.

    Returns:
        list[list]: one ``[id, prediction]`` record per row, or
        ``[id, prediction, answer]`` when ``df`` has an ``answer`` column.
    """
    # Decide once whether ground truth is available; the per-row check
    # `"answer" in row` inspected the same column labels every iteration.
    has_answer = "answer" in df.columns
    preds = []
    for i in tqdm(range(len(df))):
        row = df.iloc[i]

        # Leave the "Solution:" section empty so the model writes it.
        prompt = template.format(problem=row.problem, solution="")
        output = gemma_lm.generate(prompt, max_length=max_length)

        record = [row.id, get_answer(output)]
        if has_answer:
            record.append(row.answer)
        preds.append(record)
    return preds

In [None]:
# Run the fine-tuned model over the labelled AIMO training problems; each row of
# the resulting table pairs the parsed prediction with the ground-truth answer.
aimo_df = pd.read_csv(f"{CFG.dataset_path}/train.csv")
train_preds = infer(aimo_df)
train_pred_df = pd.DataFrame(
    train_preds,
    columns=["id", "prediction", "answer"],
)
train_pred_df

In [None]:
# Predict answers for the competition's test problems. The submission cell below
# builds a two-column frame from these records, so test.csv presumably has no
# `answer` column (each record is [id, prediction]).
test_df = pd.read_csv(f"{CFG.dataset_path}/test.csv")
test_preds = infer(test_df)

In [None]:
# Build the submission frame (columns: row_id, id, answer) and hand it to the
# competition environment.
# NOTE(review): Kaggle's AIMO evaluation API is normally driven through
# `env.iter_test()`, answering each yielded problem with its own
# `env.predict(...)` call; confirm that one predict call on the locally read
# test.csv is accepted by the hidden-test rerun.
import aimo

sub_df = pd.DataFrame(test_preds, columns=["id", "answer"])
sub_df.insert(0, "row_id", range(len(sub_df)))
env = aimo.make_env()
env.predict(sub_df)