In [1]:
import os
os.environ["KERAS_BACKEND"] = "jax" # you can also use tensorflow or torch
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00" # avoid memory fragmentation on JAX backend.

import keras
import keras_nlp

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
tqdm.pandas() # progress bar for pandas

import plotly.graph_objs as go
import plotly.express as px
from IPython.display import display, Markdown

2024-04-03 16:06:10.826803: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-03 16:06:10.826898: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-03 16:06:10.946526: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
class CFG:
    """Run-wide settings, read as class attributes (``CFG.seed``, ``CFG.preset``, ...)."""

    # Reproducibility: passed to keras.utils.set_random_seed.
    seed: int = 42
    # Competition input directory on Kaggle.
    dataset_path: str = "/kaggle/input/llm-prompt-recovery"
    # KerasNLP preset name of the pretrained Gemma checkpoint to load.
    preset: str = "gemma_instruct_2b_en"
    # Fine-tuning: preprocessor input length used while training, in tokens.
    sequence_length: int = 512
    # Fine-tuning: examples per optimisation step.
    batch_size: int = 1
    # Fine-tuning: passes over the training data.
    epochs: int = 1

In [3]:
# Fix the global random seed from the config so runs are repeatable.
keras.utils.set_random_seed(seed=CFG.seed)

In [4]:
# Prompt template shared by training (rewrite_prompt = target text) and
# inference (rewrite_prompt = ""), so the model learns to continue after the
# final "Response:" header. Placeholders: {original_text}, {rewritten_text},
# {rewrite_prompt}. The section headers match the names used in the instruction.
template = """Instruction:\nBelow, you'll find two texts: `Original Text` and `Rewritten Text`. The `Rewritten Text` has been generated from the `Original Text` using the LLM model Gemma 7b-it. Your task is to carefully analyze the similarities and differences between the two texts to determine the instruction or hint given to the LLM model to rewrite or transform the text.\n
Please consider the following 5 questions while analyzing the texts:\n
1) What linguistic features have changed between the original and rewritten text? (e.g., sentence structure, vocabulary, tone)\t
2) Are there any patterns or repetitions in the rewritten text that suggest specific transformations?\t
3) Does the rewritten text demonstrate a preference for certain word choices or syntactic structures?\t
4) How does the rewritten text differ in length or complexity compared to the original text?\t
5) Are there any contextual clues or hints in the original text that might have influenced the transformation process?\n
By carefully examining these aspects, try to deduce the underlying instruction or hint that guided the LLM model in rewriting the text from `Original Text` to `Rewritten Text`.
\n\nOriginal Text:\n{original_text}
\n\nRewritten Text:\n{rewritten_text}
\n\nResponse:\n{rewrite_prompt}"""

In [5]:
# Fine-tuning examples (50 rows) with original_text, rewritten_text and rewrite_prompt columns.
df = pd.read_csv(filepath_or_buffer="/kaggle/input/50-rnek-rewrite1-csv/rewrite1.csv")

In [6]:
# Build one training string per example: the template filled with the original
# text, its rewrite and the target rewrite prompt (the supervision signal).
# Missing cells would otherwise be formatted as the literal string "nan", so they
# are replaced by "" first, matching how the test set is cleaned before inference.
text_cols = ["original_text", "rewritten_text", "rewrite_prompt"]
df["prompt"] = df[text_cols].fillna("").progress_apply(
    lambda row: template.format(
        original_text=row.original_text,
        rewritten_text=row.rewritten_text,
        rewrite_prompt=row.rewrite_prompt,
    ),
    axis=1,
)
data = df["prompt"].tolist()

  0%|          | 0/50 [00:00<?, ?it/s]

In [7]:
# Load the pretrained Gemma causal LM named by CFG.preset and print its summary.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(preset=CFG.preset)
gemma_lm.summary()

Attaching 'config.json' from model 'keras/gemma/keras/gemma_instruct_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_instruct_2b_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_instruct_2b_en/2' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_instruct_2b_en/2' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_instruct_2b_en/2' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [8]:
# Push the first two training strings through the preprocessor to inspect
# what it returns: model inputs, labels and sample weights.
x, y, sample_weight = gemma_lm.preprocessor(data[:2])

In [9]:
# Print the shape of every model-input tensor produced by the preprocessor.
for feature_name, feature in x.items():
    print(feature_name, ":", feature.shape)

token_ids : (2, 8192)
padding_mask : (2, 8192)


In [10]:
# Turn on LoRA for the backbone (rank 4) and re-print the model summary.
lora_rank = 4
gemma_lm.backbone.enable_lora(rank=lora_rank)
gemma_lm.summary()

In [11]:
# Cap the preprocessor input length at CFG.sequence_length tokens to bound memory use.
# NOTE(review): long examples are presumably truncated from the end, which would drop
# the target rewrite_prompt at the tail of the template; worth checking token counts.
gemma_lm.preprocessor.sequence_length = CFG.sequence_length

# Cross-entropy on raw logits, Adam at a small learning rate, and token accuracy
# registered as a weighted metric so it takes the sample weights into account.
gemma_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate=3e-5),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Fine-tune on the formatted training prompts.
gemma_lm.fit(x=data, batch_size=CFG.batch_size, epochs=CFG.epochs)

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 741ms/step - loss: 2.5964 - sparse_categorical_accuracy: 0.5188


<keras.src.callbacks.history.History at 0x78357022bc40>

In [12]:
# Load the competition test split from the configured dataset directory.
# Missing texts become "" so template formatting never inserts the string "nan".
test_df = pd.read_csv(os.path.join(CFG.dataset_path, "test.csv"))
text_cols = ["original_text", "rewritten_text"]
test_df[text_cols] = test_df[text_cols].fillna("")
test_df.head()

Unnamed: 0,id,original_text,rewritten_text
0,-1,The competition dataset comprises text passage...,Here is your shanty: (Verse 1) The text is rew...


In [13]:
# Generate a rewrite-prompt prediction for every test row with the fine-tuned model.
MAX_GEN_LENGTH = 512  # max_length passed to generate(); bounds prompt + completion tokens
RESPONSE_MARKER = "Response:\n"  # final header of `template`; the answer follows it


def extract_response(output: str, prompt: str) -> str:
    """Return only the model's completion from a ``generate`` output.

    Args:
        output: Text returned by ``gemma_lm.generate`` for ``prompt``.
        prompt: The exact prompt string that was passed to ``generate``.

    Returns:
        The generated text with surrounding whitespace stripped, or ``""`` when
        no completion can be located (the submission step then applies its fallback).
    """
    if output.startswith(prompt):
        return output[len(prompt):].strip()
    # The echoed prompt may not match byte-for-byte, e.g. presumably when a long
    # prompt is cut to MAX_GEN_LENGTH tokens. In that case str.replace would leave
    # the whole prompt in the prediction. Fall back to the text after the last
    # response header, or "" if that header is missing.
    _, marker, completion = output.rpartition(RESPONSE_MARKER)
    return completion.strip() if marker else ""


preds = []
for row in tqdm(test_df.itertuples(index=False), total=len(test_df)):
    # Same template as training, with an empty response slot for the model to fill.
    prompt = template.format(
        original_text=row.original_text,
        rewritten_text=row.rewritten_text,
        rewrite_prompt="",
    )
    output = gemma_lm.generate(prompt, max_length=MAX_GEN_LENGTH)
    preds.append([row.id, extract_response(output, prompt)])

  0%|          | 0/1 [00:00<?, ?it/s]

In [14]:
# Assemble the submission. Rows where the model produced no usable text
# (NaN, empty or whitespace-only) get a generic fallback prompt.
FALLBACK_PROMPT = "Improve the essay"

sub_df = pd.DataFrame(preds, columns=["id", "rewrite_prompt"])
sub_df["rewrite_prompt"] = (
    sub_df["rewrite_prompt"]
    .fillna("")
    .map(lambda text: text if text.strip() else FALLBACK_PROMPT)
)
sub_df.to_csv("submission.csv", index=False)
sub_df.head()

Unnamed: 0,id,rewrite_prompt
0,-1,"The LLM model was given the prompt ""write a sh..."


In [15]:
from tensorflow.keras.models import load_model

In [16]:
#gemma_lm.save('/kaggle/working/model1.h5')