# 📝 Exploring Gemma_2b_en in NER Tasks

🚩In this notebook, I explore the performance of the <span style='color: brown;'><b>gemma_2b_en model</b></span> on NER (Named Entity Recognition) tasks. The learning paradigms used are <span style='color: brown;'><b>Zero Shot Learning</b></span>, <span style='color: brown;'><b>One Shot Learning</b></span>, and <span style='color: brown;'><b>Few Shot Learning</b></span> (with 3, 5, 7, and 9 shots). 


📃The Input-Dataset <span style='color: orange;'><b>"pii-combined-data"</b></span> combines data from [The Learning Agency Lab - PII Data Detection Competition](https://www.kaggle.com/competitions/pii-detection-removal-from-educational-data/data) with additional external data from [here](https://www.kaggle.com/code/valentinwerner/fix-punctuation-tokenization-external-dataset). The competition dataset contains about 22,000 essays written by students in an online course. The task is to <span style='color: brown;'><b>find and label personally identifiable information (PII)</b></span> in these essays.

📃 The Input-Dataset <span style='color: orange;'><b>"temp-test-gemma-output"</b></span> stores the saved model outputs so that the learning phase does not have to be re-run; executing all learning paradigms takes more than 3 hours. All outputs are saved in a dictionary, which saves learning time and allows further analysis.


<blockquote>    
<p style="color:green; font-weight:bold;">PII Types: </p>
    
<ol>
<li>
<u style="color: green; font-weight: bold;font-size: 12px">NAME_STUDENT</u>: Student names (excluding instructors, authors, etc.).
</li>     
<li>
<u style="color: green; font-weight: bold;font-size: 12px">EMAIL</u>: Student email addresses.
</li>
<li>
<u style="color: green; font-weight: bold;font-size: 12px">USERNAME</u>: Student usernames on any platform.
</li>
<li>
<u style="color: green; font-weight: bold;font-size: 12px">ID_NUM</u>: Student identification numbers (e.g., student ID, social security number).
</li>
<li>
<u style="color: green; font-weight: bold;font-size: 12px">PHONE_NUM</u>: Student phone numbers.
</li>
<li>
<u style="color: green; font-weight: bold;font-size: 12px">URL_PERSONAL</u>: URLs that could identify a student.
</li>
<li>
<u style="color: green; font-weight: bold;font-size: 12px">STREET_ADDRESS</u>: Student street addresses (e.g., home address).
</li>
</ol>
</blockquote> 
    
    
##  <span style='color: Orange;'><b>Overall Sections</b></span>

#### 1. **Load and Analyze Data** 📊

<blockquote>
💡In the first part, we will complete the following steps: <u style="color: brown; font-style: italic;">loading</u>, <u style="color: brown; font-style: italic;">preprocessing</u>, <u style="color: brown; font-style: italic;">splitting</u> data into training and testing datasets, and <u style="color: brown; font-style: italic;">visualizing label distribution</u> using Plotly.
</blockquote>

#### 2. **Prepare Prompt for Gemma** 🎨
<blockquote>
<p>
💡In the second part, we will prepare the <u style="color: brown; font-style: italic;">Answer dictionary</u> for each example, and then <u style="color: brown; font-style: italic;">generate prompts</u> for the different learning paradigms. 
</p>
    
<p style="font-weight:bold;">Strategies: </p>
    
<ol>
    <li>
        For <u style="color: brown; font-weight: bold;">Zero Shot Learning</u>, only include context information to inform the model of the task.
    </li>
    <li>
        For <u style="color: brown; font-weight: bold;">other learning paradigms</u>, include context information along with N examples.
    </li>
</ol>

<p style="font-weight:bold;">Here is an example of Answer dictionary (Synthetic Data): </p>
    
{<span style='font-size: 12px; color: green;'><b>'EMAIL'</b></span>: ['scottsherman@yahoo.com', 'larsenjoseph@gmail.com'],

 <span style='font-size: 12px; color: green;'><b>'ID_NUM'</b></span>: ['AGX-6811'],

 <span style='font-size: 12px; color: green;'><b>'NAME_STUDENT'</b></span>: ['Stephenxxx Morgan'],

 <span style='font-size: 12px; color: green;'><b>'PHONE_NUM'</b></span>: ['(485)728-5578', '(261)318-0141'],

 <span style='font-size: 12px; color: green;'><b>'STREET_ADDRESS'</b></span>: ['70 Harrison Manor Suite 01, Franklinville, VT 78297'],

 <span style='font-size: 12px; color: green;'><b>'URL_PERSONAL'</b></span>: ['https://github.com/123'],

 <span style='font-size: 12px; color: green;'><b>'USERNAME'</b></span>: ['nathanby']}

    
</blockquote>

#### 3. **Load Gemma Model** 🤖
<blockquote>
💡In the third part, we will import the Gemma_2b_en model from Kaggle and set <u style="color: brown; font-style: italic;">TopKSampler</u> with a seed for text generation.
</blockquote>

#### 4. **N Shot Learning Phase** 📚

#### 5. **Post-process of model's outputs** 🛠️
<blockquote>
💡In the last part, we will post-process the model's outputs by <u style="color: brown; font-style: italic;">extracting the real answers</u> and converting them into a usable <u style="color: brown; font-style: italic;">dictionary format</u> to calculate the performance metric: <u style="color: brown; font-style: italic;">F5 score</u>. (This is the metric used by the competition; it treats recall as 5 times more important than precision.) 
</blockquote>
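
To make the recall weighting concrete, here is a minimal sketch of an F-beta score computed from raw counts, assuming the standard F-beta definition. The helper name `f_beta_from_counts` and the counts are hypothetical, and the competition's exact scoring procedure is not reproduced here.

```python
def f_beta_from_counts(tp: int, fp: int, fn: int, beta: float = 5.0) -> float:
    """F-beta from true-positive, false-positive and false-negative counts."""
    beta2 = beta ** 2
    denom = (1 + beta2) * tp + beta2 * fn + fp
    return (1 + beta2) * tp / denom if denom else 0.0

# Hypothetical counts: same number of hits, but the errors differ in kind
score_low_recall = f_beta_from_counts(tp=8, fp=0, fn=2)     # precision 1.0, recall 0.8
score_low_precision = f_beta_from_counts(tp=8, fp=2, fn=0)  # precision 0.8, recall 1.0

print(round(score_low_recall, 4), round(score_low_precision, 4))
# 0.8062 0.9905
```

With beta = 5, two missed entities cost far more than two spurious ones, which is why a model that over-predicts PII can still score well on this metric.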

#### **References** 📖
<blockquote>
    
[piidd-let-s-go-higher](https://www.kaggle.com/code/hyunsoolee1010/piidd-let-s-go-higher)
    
[rule-based-approach](https://www.kaggle.com/code/emiz6413/rule-based-approach)
    
[kaggle-qa-with-gemma-kerasnlp-starter](https://www.kaggle.com/code/awsaf49/kaggle-qa-with-gemma-kerasnlp-starter)
</blockquote>

#  🔧<span style='color: Orange;'><b>0. Import necessary libraries, Set-up Configs and Seed</b></span>

In [None]:
import os
os.environ["KERAS_BACKEND"] = "jax" # supported backends: "jax", "tensorflow", or "torch"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00" # avoid memory fragmentation on JAX backend.

import time
import random
import json
import numpy as np
import pandas as pd

# pandas progress bar and learning progress bar
from tqdm.notebook import tqdm
tqdm.pandas() 

# We will use keras_nlp for Gemma model
import keras
import keras_nlp

# We will use spaCy's tokenizer and re (regex) when post-processing predictions
from spacy.lang.en import English
import re

# Import the libraries for visualisation
from IPython.display import display, Markdown
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
def set_seed(seed):
    os.environ['PYTHONHASHSEED']=str(seed)
    random.seed(seed)
    np.random.seed(seed)
    keras.utils.set_random_seed(seed)

# Set seed for reproducibility   
seed = 42 
set_seed(seed)

# Learning Configs
model_name = "gemma_2b_en" # name of pretrained Gemma
train_path = "/kaggle/input/pii-combined-data/train.json"
test_path = "/kaggle/input/pii-combined-data/moredata_dataset_fixed.json"

# Set the maximum prediction tokens length
max_length_pred_zeroshot = 1024*1
max_length_pred_oneshot = 1024*2
max_length_pred_3shot = 1024*4
max_length_pred_5shot = 1024*6
max_length_pred_7shot = 1024*7
max_length_pred_9shot = 1024*8

#  📈<span style='color: Orange;'><b>1. Load and Quickly Analyze Data</b></span>
<blockquote>    
<p style="color:brown; font-weight:bold;">Introduction to the dataset: </p>
Both train.json and moredata_dataset_fixed.json have 5 columns. 
    
Here I use labeled data as the test set (rather than test.json) so that I can calculate the performance.
    
<ol>
<li>
<u style="color: green; font-weight: bold;font-size: 12px">document (int)</u>: an integer ID of the essay.
</li>
<li>
<u style="color: green; font-weight: bold;font-size: 12px">full_text (string)</u>: a UTF-8 representation of the essay.
</li>     
<li>
<u style="color: green; font-weight: bold;font-size: 12px">tokens (list of string)</u>: list of string representations of each token.
</li>
<li>
<u style="color: green; font-weight: bold;font-size: 12px">trailing_whitespace (list of boolean)</u>: list of boolean values indicating whether each token is followed by whitespace.
</li>
<li>
<u style="color: green; font-weight: bold;font-size: 12px">labels (list of string)</u>: list of token labels in BIO format.
</li>
</ol>
</blockquote> 


In [None]:
# Split into positive samples (at least one PII label) and negative samples (only "O" labels),
# then downsample the negatives to reduce dataset imbalance.
def get_pos_neg(data):
    pos=[] # positive samples
    neg=[] # negative samples
    for d in data:
        if any(np.array(d["labels"]) != "O"): 
            pos.append(d)
        else: 
            neg.append(d)
    # downsample the negative examples
    if len(neg)>len(pos):
        neg = neg[:len(pos)//2]
        print("Downsampled the negative examples.")
    return pos,neg

def transform_to_pd(json_data):
    return pd.DataFrame({
                        "full_text": [x["full_text"] for x in json_data],
                        "document": [str(x["document"]) for x in json_data],
                        "tokens": [x["tokens"] for x in json_data],
                        "trailing_whitespace": [x["trailing_whitespace"] for x in json_data],
                        "labels": [x["labels"] for x in json_data],})

# Load train, test json data (context managers close the files after reading)
with open(train_path) as f:
    data1 = json.load(f)
with open(test_path) as f:
    data2 = json.load(f)

p1, n1 = get_pos_neg(data1)
p2, n2 = get_pos_neg(data2)

# The test source has more samples than the training source, so move 95% of its positive samples to the training set.
percent= 0.95
length = int(len(p2)*percent)
p2_train = p2[:length]
p2_test = p2[length:]

# Make train dataframe and test dataframe
train_data = transform_to_pd(p1+n1+p2_train)
test_data = transform_to_pd(p2_test+n2)

# ================Logs================
print("Original Train datapoints: ", len(data1))
print("Train datapoints: ", len(train_data))
print("Train Positive datapoints: ", len(p1)+len(p2_train))
print("Train Negative datapoints: ", len(n1))
print("\n")

print("Original Test datapoints: ", len(data2))
print("Test datapoints: ", len(test_data))
print("Test Positive datapoints: ", len(p2_test))
print("Test Negative datapoints: ", len(n2))
print("\n")

display(train_data.head(3))
print('\n')
display(test_data.head(3))
print('\n')

##  📈<span style='color: Orange;'><b>1.1 Pre-process DataFrames</b></span>

👁️‍In the dataframe, I noticed that the <span style="color:brown; font-weight:bold;">"labels" column follows the BIO labeling scheme</span>. (List of BIO labels for each token.) 

<blockquote> 
Specifically, <span style="color:green; font-weight:bold;">"B-"</span> marks the first token of an entity, and the remaining tokens of that entity are marked with <span style="color:green; font-weight:bold;">"I-"</span>. Tokens that belong to no entity are labeled "O". 
</blockquote> 

💡However, this kind of "labels" doesn't provide a clear understanding of how many labels are present in each example or what each label represents. Therefore, in this section, I utilized <span style="color:brown; font-weight:bold;">One-Hot Encoding</span> and some <span style="color:brown; font-weight:bold;">Text Processing</span> to add individual columns for each label name in the dataframe. 

<blockquote>
The values in these One-Hot columns are either 0 or 1, with 0 indicating the absence of the label and 1 indicating its presence.
</blockquote> 
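
As a small illustration of the idea, the sketch below turns BIO label sequences into One-Hot rows using plain Python. The three label sequences are made up for this example, and `label_names` is a hypothetical helper, not the notebook's own function.

```python
# Hypothetical BIO label sequences for three short documents
docs = [
    ["B-NAME_STUDENT", "I-NAME_STUDENT", "O", "B-EMAIL"],
    ["O", "O"],
    ["B-STREET_ADDRESS", "I-STREET_ADDRESS", "I-STREET_ADDRESS"],
]

def label_names(bio_labels):
    # Strip the "B-"/"I-" prefix; split only once so the label name stays intact
    return sorted({lab.split("-", 1)[1] for lab in bio_labels if lab != "O"})

per_doc = [label_names(labels) for labels in docs]
all_labels = sorted({name for names in per_doc for name in names}) + ["OTHER"]

one_hot = []
for names in per_doc:
    row = {name: int(name in names) for name in all_labels}
    row["OTHER"] = int(not names)  # "OTHER" marks documents without any PII label
    one_hot.append(row)

for row in one_hot:
    print(row)
```

Each row has a 1 for every label that occurs at least once in the document, regardless of how many tokens carry it.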

In [None]:
# Let's check out the multi-labels' presences in each example (One-Hot encoding labels)
def get_label_names(x):
    labels = set()
    for ele in x:
        if ele != "O":
            # element is like: B-STREET_ADDRESS
            # remove "B-" or "I-"
            e = ele.split('-')[-1]
            labels.add(e)
    return list(labels)

def onehot_labels(df, str_):
    # get numbers of tokens
    df["token_nums"] = df["tokens"].apply(lambda x: len(x))
    
    # get all and unique labels and total count of unique labels of each example
    df["unique_labels"] = df["labels"].apply(lambda x: get_label_names(x))
    df["unique_labels_count"] = df["unique_labels"].apply(lambda x: len(x))

    # get complete labels in dataset
    labels = list(set(element for sublist in df["unique_labels"] for element in sublist)) + ["OTHER"]
    print(f"Show all {len(labels)} labels in the {str_} dataset:\n {labels}")
    print('\n')

    # One-Hot encoding for each label
    df[labels] = 0
    for index, data_ in df.iterrows():
        if data_.unique_labels_count != 0:
            for label in data_.unique_labels:
                df.at[index, label] = 1
        else:
            df.at[index, "OTHER"] = 1
    return df, labels

train_data, train_labels = onehot_labels(train_data, "train")
test_data, test_labels = onehot_labels(test_data, "test")

sum_train = train_data[train_labels].sum(axis=0)
sum_test = test_data[test_labels].sum(axis=0)

# ================Logs================
display(train_data.head(3))
print('\n')
display(test_data.head(3))
print('\n')

print("Train dataset labels distribution: \n", sum_train)
print('\n')
print("Test dataset labels distribution: \n", sum_test)
print('\n')

##  📈<span style='color: Orange;'><b>1.2 Plotly: Labels Distribution in Train/Test Data</b></span>


In [None]:
# Plotly
font_family = "Arial"
fig = make_subplots(rows=1, cols=2, specs=[[{'type':'domain'}, {'type':'domain'}]]) #, subplot_titles=("TrainSet", "TestSet")

# Define data to plot
labels1 = train_labels
values1 = sum_train

labels2 = test_labels
values2 = sum_test

# Add the first pie chart
fig.add_trace(go.Pie(labels=labels1, values=values1, name="Distribution of labels in TrainSet", textinfo='label+percent'),
              1, 1)

# Add the second pie chart
fig.add_trace(go.Pie(labels=labels2, values=values2, name="Distribution of labels in TestSet", textinfo='label+percent'),
              1, 2)

# Update Layout Settings
fig.update_layout(width=1000, height=700, margin=dict(l=75,r=50,b=50,t=50,),paper_bgcolor="WHITE")

fig.update_layout(title_text='Distribution of labels in TrainSet / TestSet',
                  font=dict(family=font_family,
                            size=10,
                            color="#75767A"),
                 )
fig.update_layout(title_font_family=font_family,
                  title_font_color='BLACK',
                  title_font_size=18,
                  title_x=0.5)

fig.update_layout(legend_title="",
                  showlegend=True,
                  legend=dict(x=0, y=1, orientation="h")
                 )

fig.add_annotation(text="TrainSet", x=0.18, y=0.05, xref="paper", yref="paper",
                   showarrow=False, font=dict(size=16))
fig.add_annotation(text="TestSet", x=0.82, y=0.05, xref="paper", yref="paper",
                   showarrow=False, font=dict(size=16))

fig.show()

#  ✏️<span style='color: Orange;'><b>2. Prepare Learning Prompts for Gemma</b></span>


###  🪄<span style='color: Orange;'><b>2.1. Prepare Answer Dictionary</b></span>

After adding the One-Hot columns, we also need to create a <span style='color: brown;'><b>value list for each label</b></span> that holds the corresponding strings from the text. With these value lists, we can create an <span style='color: brown;'><b>"answers" column</b></span> that combines all the labels and their values into a dictionary.
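
The sketch below shows the idea on a single made-up document: tokens tagged "B-"/"I-" for one label are glued back together using the trailing-whitespace flags. The function `extract_entities` and the sample data are hypothetical and only illustrate the approach.

```python
def extract_entities(tokens, labels, trailing_whitespace, label_name):
    """Collect the distinct surface strings tagged with `label_name`."""
    found, current = [], None
    for token, label, space in zip(tokens, labels, trailing_whitespace):
        piece = token + (" " if space else "")
        if label == "B-" + label_name:
            # A new entity starts: store the one collected so far
            if current is not None:
                found.append(current.strip())
            current = piece
        elif label == "I-" + label_name and current is not None:
            current += piece
    if current is not None:
        found.append(current.strip())
    # Preserve order while dropping duplicates
    return list(dict.fromkeys(found))

tokens = ["Contact", "Jane", "Doe", "at", "jane.doe@example.com", "."]
labels = ["O", "B-NAME_STUDENT", "I-NAME_STUDENT", "O", "B-EMAIL", "O"]
spaces = [True, True, True, True, False, False]

answers = {name: extract_entities(tokens, labels, spaces, name)
           for name in ["EMAIL", "NAME_STUDENT", "USERNAME"]}
print(answers)
# {'EMAIL': ['jane.doe@example.com'], 'NAME_STUDENT': ['Jane Doe'], 'USERNAME': []}
```

Labels that do not occur in a document map to an empty list, which matches the shape of the Answer dictionary shown in the overview.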

In [None]:
# Get the string/values for each label
def value_of_labels(df, label_name):
    new_col = label_name+'_VALUE'
    df[new_col] = 'Nan'
    for row_id, row in df.iterrows():
        s = None
        true_ = []
        for token, label, space in zip(row["tokens"], row["labels"], row["trailing_whitespace"]):
            if label == "B-"+label_name:
                if s is not None:
                    if s.strip() not in true_:
                        true_.append(s.strip())
                space = " " if space else ""
                s = token + space
            elif label == "I-"+label_name:
                space = " " if space else ""
                # Guard against an "I-" tag that is not preceded by a "B-" tag (s would be None)
                s = (s or "") + token + space
        if s is not None:
            if s.strip() not in true_:
                true_.append(s.strip())
        df.at[row_id, new_col] = true_
    return df

# Get the list of values of each label
for label_name in train_labels[:-1]: #exclude OTHER
    value_of_labels(train_data, label_name)
    value_of_labels(test_data, label_name)

train_data['answers'] = train_data.apply(lambda row: {"EMAIL": row['EMAIL_VALUE'],
                                                        "ID_NUM": row['ID_NUM_VALUE'],
                                                        "NAME_STUDENT": row['NAME_STUDENT_VALUE'],
                                                        "PHONE_NUM": row['PHONE_NUM_VALUE'],
                                                        "STREET_ADDRESS": row['STREET_ADDRESS_VALUE'],
                                                        "URL_PERSONAL": row['URL_PERSONAL_VALUE'],
                                                        "USERNAME": row['USERNAME_VALUE']}, axis=1)
test_data['answers'] = test_data.apply(lambda row: {"EMAIL": row['EMAIL_VALUE'],
                                                        "ID_NUM": row['ID_NUM_VALUE'],
                                                        "NAME_STUDENT": row['NAME_STUDENT_VALUE'],
                                                        "PHONE_NUM": row['PHONE_NUM_VALUE'],
                                                        "STREET_ADDRESS": row['STREET_ADDRESS_VALUE'],
                                                        "URL_PERSONAL": row['URL_PERSONAL_VALUE'],
                                                        "USERNAME": row['USERNAME_VALUE']}, axis=1)

# ================Logs================
display(train_data.head(3))
display(test_data.head(3))

print(f"Show an example of train answer:\n {train_data.answers[0]} ")
print('\n')
print(f"Show an example of test answer:\n {test_data.answers[0]} ")
print('\n')

###  🪄<span style='color: Orange;'><b>2.2. Prompt for Zero-Shot Learning</b></span>

<span style='color: green;'><b>Zero-Shot Learning: Only providing context to define specific tasks, </b></span> without providing any examples beyond that.

Define the <span style='color: brown;'><b>Context</b></span>:
<blockquote>

Please identify the named entities in the text below, including student's name(NAME_STUDENT), ID(ID_NUM), email address(EMAIL), phone number(PHONE_NUM), street address(STREET_ADDRESS), personal URL(URL_PERSONAL), and username(USERNAME).

List each category and its entities in JSON format: {'EMAIL': [], 'ID_NUM': [], 'NAME_STUDENT': [], 'PHONE_NUM': [], 'STREET_ADDRESS': [], 'URL_PERSONAL': [], 'USERNAME': []}

</blockquote>
    
Define the <span style='color: brown;'><b>template of the zeroshot_prompt</b></span>: 
<blockquote>
    Context + Text + Answer.
</blockquote>
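
To show how the template is filled, here is a minimal sketch with a shortened context and a made-up sentence. The names `ZERO_SHOT_TEMPLATE`, `context`, and `tokens` are chosen for this illustration only.

```python
ZERO_SHOT_TEMPLATE = "\n\nContext:\n{Context}\n\nText:\n{Text}\n\nAnswer:\n{Answer}"

# Shortened, hypothetical context. Its literal braces are safe because the string
# is passed as a format *argument* and is not parsed as a template itself.
context = (
    "Identify the student's name(NAME_STUDENT) and email address(EMAIL). "
    "Answer in JSON format: {'EMAIL': [], 'NAME_STUDENT': []}"
)

tokens = ["My", "name", "is", "Jane", "Doe", "."]
prompt = ZERO_SHOT_TEMPLATE.format(Context=context, Text=" ".join(tokens), Answer="")
print(prompt)
```

The Answer slot is left empty, so the prompt ends with "Answer:" followed by a newline and the model's continuation is read as its answer.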

In [None]:
# Generate zero-shot example's prompt
zeroshot_prompt = "\n\nContext:\n{Context}\n\nText:\n{Text}\n\nAnswer:\n{Answer}"

context_zeroshot = f"""Please identify the named entities in the text below, including student's name(NAME_STUDENT), ID(ID_NUM), email address(EMAIL), 
phone number(PHONE_NUM), street address(STREET_ADDRESS), personal url(URL_PERSONAL) and user name(USERNAME). List each category and its entities in json format.
{{'EMAIL': [], 'ID_NUM': [], 'NAME_STUDENT': [], 'PHONE_NUM': [], 'STREET_ADDRESS': [], 'URL_PERSONAL': [], 'USERNAME': []}}"""

# Generate test example's prompts
test_data['zeroshot_prompts'] = test_data.apply(lambda row: zeroshot_prompt.format(
                                                                            Context=context_zeroshot.strip(), 
                                                                            Text=' '.join(row["tokens"]), #[:train_token_length_in_prompt]
                                                                            Answer=""), axis=1)

# Check prompts
def colorize_test_text(text):
    for word, color in zip(["Context","Text","Answer"], 
                           ["blue","red","green"]):
        text = text.replace(f"\n\n{word}:", f"\n\n<span style='color: {color};'><b>{word}:</b></span>")
    return text

def display_test_prompt(sample):
    # Give colors to Text and Answer
    sample = colorize_test_text(sample)
    # Show sample in markdown
    display(Markdown(sample))

for i in range(1):
    sample = test_data['zeroshot_prompts'].tolist()[i]
    display_test_prompt(sample)
    print('\n')

###  🪄<span style='color: Orange;'><b>2.3. Prompt for One-Shot Learning</b></span>

<span style='color: green;'><b>One-Shot Learning: Providing context to define the specific task, plus one example (Text + Answer).</b></span>

Define the <span style='color: brown;'><b>Context</b></span>:
<blockquote>

Please identify the named entities in the text below, including student's name(NAME_STUDENT), ID(ID_NUM), email address(EMAIL), phone number(PHONE_NUM), street address(STREET_ADDRESS), personal URL(URL_PERSONAL), and username(USERNAME).

List each category and its entities in JSON format: {'EMAIL': [], 'ID_NUM': [], 'NAME_STUDENT': [], 'PHONE_NUM': [], 'STREET_ADDRESS': [], 'URL_PERSONAL': [], 'USERNAME': []}

</blockquote>
    
Define the <span style='color: brown;'><b>template of the oneshot_prompt</b></span>: 
<blockquote>
    Context + Example1(Text+Answer) + Text + Answer.
</blockquote>

In [None]:
# Gather 1 example:
# Here I have chosen row index 2501, which contains all the labels and has the fewest tokens.
# train_data[train_data['unique_labels_count']==7].sort_values(by=['token_nums'])
row_oneshot = [2501] 
# .copy() makes the selection an independent DataFrame before the 'prompts' column is added
df_examples_oneshot = train_data.loc[row_oneshot].copy()

# Generate example's prompt
train_template_prompt = "\n\nText:\n{Text}\n\nAnswer:\n{Answer}"

# Generate example's prompts
df_examples_oneshot['prompts'] = df_examples_oneshot.apply(lambda row: train_template_prompt.format(
                                                                    Text=' '.join(row["tokens"]), #[:train_token_length_in_prompt]), 
                                                                    Answer=row["answers"]), axis=1)

# Check prompts
def colorize_text(text):
    for word, color in zip(["Text", "Answer"], ["red", "green"]):
        text = text.replace(f"\n\n{word}:", f"\n\n<span style='color: {color};'><b>{word}:</b></span>")
    return text

def display_prompt(sample):
    # Give colors to Text and Answer
    sample = colorize_text(sample)
    # Show sample in markdown
    display(Markdown(sample))

for i in range(1):
    sample = df_examples_oneshot['prompts'].tolist()[i]
    display_prompt(sample)
    print('\n')

In [None]:
# Generate one-shot test's prompt
oneshot_prompt = "\n\nContext:\n{Context}\n\nExample1:\n{Example1}\n\nText:\n{Text}\n\nAnswer:\n{Answer}"

example_prompts_oneshot = df_examples_oneshot['prompts'].tolist()

# Generate test example's prompts
test_data['oneshot_prompts'] = test_data.apply(lambda row: oneshot_prompt.format(
                                                                            Context=context_zeroshot.strip(), 
                                                                            Example1=example_prompts_oneshot[0], 
                                                                            Text=' '.join(row["tokens"]), #[:train_token_length_in_prompt]
                                                                            Answer=""), axis=1)

# Check prompts
def colorize_test_text_oneshot(text):
    for word, color in zip(["Context","Example1","Text","Answer"], 
                           ["blue","orange","red","green"]):
        text = text.replace(f"\n\n{word}:", f"\n\n<span style='color: {color};'><b>{word}:</b></span>")
    return text

def display_test_prompt_oneshot(sample):
    # Give colors to Text and Answer
    sample = colorize_test_text_oneshot(sample)
    # Show sample in markdown
    display(Markdown(sample))

for i in range(1):
    sample = test_data['oneshot_prompts'].tolist()[i]
    display_test_prompt_oneshot(sample)
    print('\n')

###  🪄<span style='color: Orange;'><b>2.4. Prompt for Three-Shot Learning</b></span>

<span style='color: green;'><b>Three-Shot Learning: Providing context to define the specific task, plus three examples (each with Text + Answer).</b></span>

Define the <span style='color: brown;'><b>Context</b></span>:
<blockquote>

Please identify the named entities in the text below, including student's name(NAME_STUDENT), ID(ID_NUM), email address(EMAIL), phone number(PHONE_NUM), street address(STREET_ADDRESS), personal URL(URL_PERSONAL), and username(USERNAME).

List each category and its entities in JSON format: {'EMAIL': [], 'ID_NUM': [], 'NAME_STUDENT': [], 'PHONE_NUM': [], 'STREET_ADDRESS': [], 'URL_PERSONAL': [], 'USERNAME': []}

</blockquote>
    
Define the <span style='color: brown;'><b>template of the threeshot_prompt</b></span>: 
<blockquote>
    Context + Example1(Text+Answer) + Example2(Text+Answer) + Example3(Text+Answer) + Text + Answer.
</blockquote>
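
Since the 3-, 5-, 7-, and 9-shot templates differ only in the number of Example blocks, they could also be assembled by one helper. The following is a sketch under that assumption; `build_n_shot_prompt` and the toy examples are hypothetical and not part of the notebook.

```python
EXAMPLE_TEMPLATE = "\n\nText:\n{Text}\n\nAnswer:\n{Answer}"

def build_n_shot_prompt(context, examples, query_text):
    """Assemble Context + Example1..ExampleN + Text + an empty Answer slot."""
    parts = [f"\n\nContext:\n{context}"]
    for i, (text, answer) in enumerate(examples, start=1):
        shot = EXAMPLE_TEMPLATE.format(Text=text, Answer=answer)
        parts.append(f"\n\nExample{i}:\n{shot}")
    parts.append(f"\n\nText:\n{query_text}\n\nAnswer:\n")
    return "".join(parts)

# Toy (text, answer) pairs standing in for rows chosen from the training data
examples = [
    ("I am Jane Doe .", {"NAME_STUDENT": ["Jane Doe"]}),
    ("No personal data here .", {"NAME_STUDENT": []}),
    ("Written by John Roe .", {"NAME_STUDENT": ["John Roe"]}),
]
prompt = build_n_shot_prompt("Find student names.", examples, "Hello , I am Ann Lee .")
print(prompt)
```

Passing an empty list of examples yields the zero-shot layout, so a single function would cover every paradigm in this notebook.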

In [None]:
# Gather 3 examples:
# You can choose any row indices you like.
# Here I have chosen row index 1319, which contains 0 labels and has the fewest tokens,
#                    row index 2049, which contains 4 labels and has the fewest tokens,
#                    row index 2501, which contains 7 labels and has the fewest tokens.

row_3shot = [1319,2049,2501] 
# .copy() makes the selection an independent DataFrame before the 'prompts' column is added
df_examples_3shot = train_data.loc[row_3shot].copy()

# Generate example's prompt
train_template_prompt = "\n\nText:\n{Text}\n\nAnswer:\n{Answer}"

# Generate example's prompts
df_examples_3shot['prompts'] = df_examples_3shot.apply(lambda row: train_template_prompt.format(
                                                                    Text=' '.join(row["tokens"]), #[:train_token_length_in_prompt]), 
                                                                    Answer=row["answers"]), axis=1)

# Generate three-shot test's prompt
prompt_3shot = "\n\nContext:\n{Context}\n\nExample1:\n{Example1}\n\nExample2:\n{Example2}\n\nExample3:\n{Example3}\n\nText:\n{Text}\n\nAnswer:\n{Answer}"

context = f"""Please identify the named entities in the text below, including student's name(NAME_STUDENT), ID(ID_NUM), email address(EMAIL), 
phone number(PHONE_NUM), street address(STREET_ADDRESS), personal url(URL_PERSONAL) and user name(USERNAME). List each category and its entities in json format."""

example_prompts_3shot = df_examples_3shot['prompts'].tolist()

# Generate test example's prompts
test_data['prompt_3shot'] = test_data.apply(lambda row: prompt_3shot.format(
                                                                            Context=context_zeroshot.strip(), 
                                                                            Example1=example_prompts_3shot[0], 
                                                                            Example2=example_prompts_3shot[1], 
                                                                            Example3=example_prompts_3shot[2], 
                                                                            Text=' '.join(row["tokens"]), #[:train_token_length_in_prompt]
                                                                            Answer=""), axis=1)

In [None]:
# for i in range(1):
#     sample = df_examples_3shot['prompts'].tolist()[i]
#     display_prompt(sample)
#     print('\n')

# # Check prompts
# def colorize_test_text_fewshot(text):
#     for word, color in zip(["Context","Example1","Example2","Example3","Text","Answer"], 
#                            ["blue","orange","orange","orange","red","green"]):
#         text = text.replace(f"\n\n{word}:", f"\n\n<span style='color: {color};'><b>{word}:</b></span>")
#     return text

# def display_test_prompt_fewshot(sample):
#     sample = colorize_test_text_fewshot(sample)
#     display(Markdown(sample))

# for i in range(1):
#     sample = test_data['prompt_3shot'].tolist()[i]
#     display_test_prompt_fewshot(sample)
#     print('\n')

###  🪄<span style='color: Orange;'><b>2.5. Prompt for Five-Shot Learning</b></span>

<span style='color: green;'><b>Five-Shot Learning: Providing context to define the specific task, plus five examples (each with Text + Answer).</b></span>

Define the <span style='color: brown;'><b>Context</b></span>:
<blockquote>

Please identify the named entities in the text below, including student's name(NAME_STUDENT), ID(ID_NUM), email address(EMAIL), phone number(PHONE_NUM), street address(STREET_ADDRESS), personal URL(URL_PERSONAL), and username(USERNAME).

List each category and its entities in JSON format: {'EMAIL': [], 'ID_NUM': [], 'NAME_STUDENT': [], 'PHONE_NUM': [], 'STREET_ADDRESS': [], 'URL_PERSONAL': [], 'USERNAME': []}

</blockquote>
    
Define the <span style='color: brown;'><b>template of the fiveshot_prompt</b></span>: 
<blockquote>
    Context + Example1(Text+Answer) + Example2(Text+Answer) + Example3(Text+Answer) + Example4(Text+Answer) + Example5(Text+Answer) + Text + Answer.
</blockquote>

In [None]:
# Gather 5 examples:
# You can choose any row indices you like.
# Here I have chosen row index 1319, which contains 0 labels and has the fewest tokens,
#                    row index 2049, which contains 4 labels and has the fewest tokens,
#                    row indices 2501, 2593, 2680, which contain 7 labels and the fewest tokens.

row_5shot = [1319,2049,2501,2593,2680] 
# .copy() makes the selection an independent DataFrame before the 'prompts' column is added
df_examples_5shot = train_data.loc[row_5shot].copy()

# Generate example's prompt
train_template_prompt = "\n\nText:\n{Text}\n\nAnswer:\n{Answer}"

# Generate example's prompts
df_examples_5shot['prompts'] = df_examples_5shot.apply(lambda row: train_template_prompt.format(
                                                                    Text=' '.join(row["tokens"]), #[:train_token_length_in_prompt]), 
                                                                    Answer=row["answers"]), axis=1)

# Generate five-shot test's prompt
prompt_5shot = "\n\nContext:\n{Context}\n\nExample1:\n{Example1}\n\nExample2:\n{Example2}\n\nExample3:\n{Example3}\n\nExample4:\n{Example4}\n\nExample5:\n{Example5}\n\nText:\n{Text}\n\nAnswer:\n{Answer}"

context = f"""Please identify the named entities in the text below, including student's name(NAME_STUDENT), ID(ID_NUM), email address(EMAIL), 
phone number(PHONE_NUM), street address(STREET_ADDRESS), personal url(URL_PERSONAL) and user name(USERNAME). List each category and its entities in json format."""

example_prompts_5shot = df_examples_5shot['prompts'].tolist()

# Generate test example's prompts
test_data['prompt_5shot'] = test_data.apply(lambda row: prompt_5shot.format(
                                                                            Context=context_zeroshot.strip(), 
                                                                            Example1=example_prompts_5shot[0], 
                                                                            Example2=example_prompts_5shot[1], 
                                                                            Example3=example_prompts_5shot[2],
                                                                            Example4=example_prompts_5shot[3],
                                                                            Example5=example_prompts_5shot[4],
                                                                            Text=' '.join(row["tokens"]), #[:train_token_length_in_prompt]
                                                                            Answer=""), axis=1)

In [None]:
# for i in range(1):
#     sample = df_examples_5shot['prompts'].tolist()[i]
#     display_prompt(sample)
#     print('\n')

# # Check prompts
# def colorize_test_text_5shot(text):
#     for word, color in zip(["Context","Example1","Example2","Example3","Example4","Example5","Text","Answer"], 
#                            ["blue","orange","orange","orange","orange","orange","red","green"]):
#         text = text.replace(f"\n\n{word}:", f"\n\n<span style='color: {color};'><b>{word}:</b></span>")
#     return text

# def display_test_prompt_5shot(sample):
#     # Give colors to Text and Answer
#     sample = colorize_test_text_5shot(sample)
#     # Show sample in markdown
#     display(Markdown(sample))

# for i in range(1):
#     sample = test_data['prompt_5shot'].tolist()[i]
#     display_test_prompt_5shot(sample)
#     print('\n')

###  🪄<span style='color: Orange;'><b>2.6. Prompt for Seven-Shot Learning</b></span>

<span style='color: green;'><b>Seven-Shot Learning: Provide context that defines the specific task, plus seven examples (each including Text + Answer).</b></span>

Define the <span style='color: brown;'><b>Context</b></span>:
<blockquote>

Please identify the named entities in the text below, including student's name(NAME_STUDENT), ID(ID_NUM), email address(EMAIL), phone number(PHONE_NUM), street address(STREET_ADDRESS), personal URL(URL_PERSONAL), and username(USERNAME).

List each category and its entities in JSON format: {'EMAIL': [], 'ID_NUM': [], 'NAME_STUDENT': [], 'PHONE_NUM': [], 'STREET_ADDRESS': [], 'URL_PERSONAL': [], 'USERNAME': []}

</blockquote>
    
Define the <span style='color: brown;'><b>template of the sevenshot_prompt</b></span>: 
<blockquote>
    Context + Example1(Text+Answer) + Example2(Text+Answer) + Example3(Text+Answer) + Example4(Text+Answer) + Example5(Text+Answer) + Example6(Text+Answer) + Example7(Text+Answer) + Text + Answer.
</blockquote>
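
The template above can also be assembled programmatically for any number of examples. The following is only a minimal sketch: `build_nshot_prompt` is a hypothetical helper that is not part of this Notebook, and the context and example strings are made-up placeholders. It produces the same string layout as the hand-written templates.

```python
def build_nshot_prompt(context, examples, text, answer=""):
    """Assemble Context + Example1..ExampleN + Text + Answer into one prompt."""
    parts = [f"\n\nContext:\n{context}"]
    # Number the examples from 1, as in the hand-written templates
    for i, example in enumerate(examples, start=1):
        parts.append(f"\n\nExample{i}:\n{example}")
    parts.append(f"\n\nText:\n{text}\n\nAnswer:\n{answer}")
    return "".join(parts)


# Placeholder data (hypothetical, not taken from the dataset)
context = "Please identify the named entities in the text below."
examples = [
    "\n\nText:\nHello , I am Ann .\n\nAnswer:\n{'NAME_STUDENT': ['Ann']}",
    "\n\nText:\nNo names here .\n\nAnswer:\n{'NAME_STUDENT': []}",
]
prompt = build_nshot_prompt(context, examples, "My name is Bob .")
print(prompt)
```

With such a helper, moving from seven to nine shots only means passing a longer list of examples.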

In [None]:
# Gather 7 examples:
# You can choose any row indices you like.
# Here I have chosen row index 1319, which contains 0 labels and has the fewest tokens;
#                    row indices 2049, 1556, which contain 4 labels and have the fewest tokens;
#                    row indices 2501, 2593, 2680, 2051, which contain 7 labels and have the fewest tokens.

row_7shot = [1319,2049,1556,2501,2593,2680,2051] 
# .copy() avoids pandas' SettingWithCopyWarning when the 'prompts' column is added below
df_examples_7shot = train_data.loc[row_7shot].copy()

# Generate example's prompt
train_template_prompt = "\n\nText:\n{Text}\n\nAnswer:\n{Answer}"

# Generate example's prompts
df_examples_7shot['prompts'] = df_examples_7shot.apply(lambda row: train_template_prompt.format(
                                                                    Text=' '.join(row["tokens"]), #[:train_token_length_in_prompt]), 
                                                                    Answer=row["answers"]), axis=1)

# Generate seven-shot test's prompt
prompt_7shot = "\n\nContext:\n{Context}\n\nExample1:\n{Example1}\n\nExample2:\n{Example2}\n\nExample3:\n{Example3}\n\nExample4:\n{Example4}\n\nExample5:\n{Example5}\n\nExample6:\n{Example6}\n\nExample7:\n{Example7}\n\nText:\n{Text}\n\nAnswer:\n{Answer}"

context = f"""Please identify the named entities in the text below, including student's name(NAME_STUDENT), ID(ID_NUM), email address(EMAIL), 
phone number(PHONE_NUM), street address(STREET_ADDRESS), personal url(URL_PERSONAL) and user name(USERNAME). List each category and its entities in json format."""

example_prompts_7shot = df_examples_7shot['prompts'].tolist()

# Generate test example's prompts
test_data['prompt_7shot'] = test_data.apply(lambda row: prompt_7shot.format(
                                                                            Context=context_zeroshot.strip(), 
                                                                            Example1=example_prompts_7shot[0], 
                                                                            Example2=example_prompts_7shot[1], 
                                                                            Example3=example_prompts_7shot[2],
                                                                            Example4=example_prompts_7shot[3],
                                                                            Example5=example_prompts_7shot[4],
                                                                            Example6=example_prompts_7shot[5],
                                                                            Example7=example_prompts_7shot[6],
                                                                            Text=' '.join(row["tokens"]), #[:train_token_length_in_prompt]
                                                                            Answer=""), axis=1)

In [None]:
# for i in range(1):
#     sample = df_examples_7shot['prompts'].tolist()[i]
#     display_prompt(sample)
#     print('\n')

# # Check prompts
# def colorize_test_text_7shot(text):
#     for word, color in zip(["Context","Example1","Example2","Example3","Example4","Example5","Example6","Example7","Text","Answer"], 
#                            ["blue","orange","orange","orange","orange","orange","orange","orange","red","green"]):
#         text = text.replace(f"\n\n{word}:", f"\n\n<span style='color: {color};'><b>{word}:</b></span>")
#     return text

# def display_test_prompt_7shot(sample):
#     # Give colors to Text and Answer
#     sample = colorize_test_text_7shot(sample)
#     # Show sample in markdown
#     display(Markdown(sample))

# for i in range(1):
#     sample = test_data['prompt_7shot'].tolist()[i]
#     display_test_prompt_7shot(sample)
#     print('\n')

###  🪄<span style='color: Orange;'><b>2.7. Prompt for Nine-Shot Learning</b></span>

<span style='color: green;'><b>Nine-Shot Learning: Provide context that defines the specific task, plus nine examples (each including Text + Answer).</b></span>

Define the <span style='color: brown;'><b>Context</b></span>:
<blockquote>

Please identify the named entities in the text below, including student's name(NAME_STUDENT), ID(ID_NUM), email address(EMAIL), phone number(PHONE_NUM), street address(STREET_ADDRESS), personal URL(URL_PERSONAL), and username(USERNAME).

List each category and its entities in JSON format: {'EMAIL': [], 'ID_NUM': [], 'NAME_STUDENT': [], 'PHONE_NUM': [], 'STREET_ADDRESS': [], 'URL_PERSONAL': [], 'USERNAME': []}

</blockquote>
    
Define the <span style='color: brown;'><b>template of the nineshot_prompt</b></span>: 
<blockquote>
    Context + Example1(Text+Answer) + Example2(Text+Answer) + Example3(Text+Answer) + Example4(Text+Answer) + Example5(Text+Answer) + Example6(Text+Answer) + Example7(Text+Answer) + Example8(Text+Answer) + Example9(Text+Answer) + Text + Answer.
</blockquote>

In [None]:
# Gather 9 examples:
# You can choose any row indices you like.
# Here I have chosen row indices 1319, 1340, which contain 0 labels and have the fewest tokens;
#                    row indices 2049, 1556, which contain 4 labels and have the fewest tokens;
#                    row indices 2501, 2593, 2680, 2051, 2765, which contain 7 labels and have the fewest tokens.

row_9shot = [1319,1340,2049,1556,2501,2593,2680,2051,2765] 
# .copy() avoids pandas' SettingWithCopyWarning when the 'prompts' column is added below
df_examples_9shot = train_data.loc[row_9shot].copy()

# Generate example's prompt
train_template_prompt = "\n\nText:\n{Text}\n\nAnswer:\n{Answer}"

# Generate example's prompts
df_examples_9shot['prompts'] = df_examples_9shot.apply(lambda row: train_template_prompt.format(
                                                                    Text=' '.join(row["tokens"]), #[:train_token_length_in_prompt]), 
                                                                    Answer=row["answers"]), axis=1)

# Generate nine-shot test's prompt
prompt_9shot = "\n\nContext:\n{Context}\n\nExample1:\n{Example1}\n\nExample2:\n{Example2}\n\nExample3:\n{Example3}\n\nExample4:\n{Example4}\n\nExample5:\n{Example5}\n\nExample6:\n{Example6}\n\nExample7:\n{Example7}\n\nExample8:\n{Example8}\n\nExample9:\n{Example9}\n\nText:\n{Text}\n\nAnswer:\n{Answer}"

context = f"""Please identify the named entities in the text below, including student's name(NAME_STUDENT), ID(ID_NUM), email address(EMAIL), 
phone number(PHONE_NUM), street address(STREET_ADDRESS), personal url(URL_PERSONAL) and user name(USERNAME). List each category and its entities in json format."""

example_prompts_9shot = df_examples_9shot['prompts'].tolist()


# Generate test example's prompts
test_data['prompt_9shot'] = test_data.apply(lambda row: prompt_9shot.format(
                                                                            Context=context_zeroshot.strip(), 
                                                                            Example1=example_prompts_9shot[0], 
                                                                            Example2=example_prompts_9shot[1], 
                                                                            Example3=example_prompts_9shot[2],
                                                                            Example4=example_prompts_9shot[3],
                                                                            Example5=example_prompts_9shot[4],
                                                                            Example6=example_prompts_9shot[5],
                                                                            Example7=example_prompts_9shot[6],
                                                                            Example8=example_prompts_9shot[7],
                                                                            Example9=example_prompts_9shot[8],
                                                                            Text=' '.join(row["tokens"]), #[:train_token_length_in_prompt]
                                                                            Answer=""), axis=1)

In [None]:
# for i in range(1):
#     sample = df_examples_9shot['prompts'].tolist()[i]
#     display_prompt(sample)
#     print('\n')

# # Check prompts
# def colorize_test_text_9shot(text):
#     for word, color in zip(["Context","Example1","Example2","Example3","Example4","Example5","Example6","Example7","Example8","Example9","Text","Answer"], 
#                            ["blue","orange","orange","orange","orange","orange","orange","orange","orange","orange","red","green"]):
#         text = text.replace(f"\n\n{word}:", f"\n\n<span style='color: {color};'><b>{word}:</b></span>")
#     return text

# def display_test_prompt_9shot(sample):
#     sample = colorize_test_text_9shot(sample)
#     display(Markdown(sample))

# for i in range(1):
#     sample = test_data['prompt_9shot'].tolist()[i]
#     display_test_prompt_9shot(sample)
#     print('\n')

# 🤖<span style='color: Orange;'><b>3. Load Gemma_2b_en Model</b></span>


On Kaggle, [Gemma series](https://www.kaggle.com/models/keras/gemma) models are readily available. 

You can choose either <span style='color: brown;'><b>Gemma 2b</b></span> or <span style='color: brown;'><b>Gemma 7b</b></span>.

The sampling strategy specified for text generation is <span style='color: brown;'><b>TopKSampler</b></span>: at each position, it keeps only the k tokens with the highest probabilities and samples the next token from among them. This keeps the generated text diverse while excluding unlikely tokens.
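
To make the idea concrete, here is a small, library-free sketch of top-k sampling over a list of scores. It is a conceptual illustration only, not the KerasNLP implementation; the function name `top_k_sample` and the toy `logits` are assumptions made for this example.

```python
import math
import random


def top_k_sample(logits, k, rng):
    """Sample one token id from the k highest-scoring entries of `logits`."""
    # Indices of the k largest logits
    top_ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax weights restricted to the kept logits (shifted by the max for stability)
    max_logit = max(logits[i] for i in top_ids)
    weights = [math.exp(logits[i] - max_logit) for i in top_ids]
    # Draw one id with probability proportional to its weight
    return rng.choices(top_ids, weights=weights, k=1)[0]


rng = random.Random(42)
logits = [2.0, 0.5, 1.5, -1.0, 0.1, 3.0]  # toy scores for a 6-token vocabulary
samples = [top_k_sample(logits, k=3, rng=rng) for _ in range(1000)]
# Only the three highest-scoring ids (5, 0 and 2) can ever be drawn
print(sorted(set(samples)))
```

A smaller k makes the output more deterministic, while a larger k allows more varied (and potentially noisier) generations.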

In [None]:
# You can activate 'mixed_bfloat16' to reduce the learning time, but it may also reduce the performance.
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

gemma_tokenizer = keras_nlp.models.GemmaTokenizer.from_preset(model_name)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(model_name)
gemma_lm.summary()


sampler = keras_nlp.samplers.TopKSampler(k=5, seed=seed)
gemma_lm.compile(sampler=sampler)

gemma_lm.dtype_policy

# 📖<span style='color: Orange;'><b>4. N shot Learning</b></span>

During the learning phase, I record the elapsed time (in seconds) of each paradigm in the <span style='color: brown;'><b>all_test_time</b></span> variable, in the order in which the paradigms are run.
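
The generate-and-time pattern is the same for every paradigm, so it could be wrapped in one helper. The sketch below is only an illustration under assumptions: `timed_generate`, `format_elapsed` and `fake_generate` are hypothetical names, and `fake_generate` merely stands in for `gemma_lm.generate` so that the example runs without loading the model.

```python
import time


def timed_generate(generate_fn, prompts, max_length):
    """Run generate_fn on every prompt and return (outputs, elapsed_seconds)."""
    start = time.perf_counter()
    outputs = [generate_fn(prompt, max_length=max_length) for prompt in prompts]
    elapsed = time.perf_counter() - start
    return outputs, elapsed


def format_elapsed(seconds):
    """Report a duration in seconds, minutes and hours."""
    return f"{seconds:.1f} seconds, {seconds / 60:.2f} min, {seconds / 3600:.2f} h"


# Stand-in for gemma_lm.generate (assumption: same call shape, trivial behaviour)
def fake_generate(prompt, max_length):
    return prompt[:max_length]


outs, elapsed = timed_generate(fake_generate, ["abc", "defgh"], max_length=4)
print(outs)
print(format_elapsed(elapsed))
```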

###  📖<span style='color: Orange;'><b>4.1 Zero Shot Learning</b></span>

In [None]:
all_test_time = []

In [None]:
start_time_zeroshot = time.time()

zeroshot_outs = []
for prompt in tqdm(test_data["zeroshot_prompts"]):
    out = gemma_lm.generate(prompt, max_length=max_length_pred_zeroshot)
    zeroshot_outs.append(out)
    
end_time_zeroshot = time.time() 
elapsed_time_zeroshot = end_time_zeroshot - start_time_zeroshot
all_test_time.append(elapsed_time_zeroshot)

print(f"Execution Time:\n{elapsed_time_zeroshot} seconds, {round(elapsed_time_zeroshot/60, 2)} min, {round(elapsed_time_zeroshot/3600, 2)} h")

###  📖<span style='color: Orange;'><b>4.2 One Shot Learning</b></span>

In [None]:
start_time_oneshot = time.time()

oneshot_outs = []
for prompt in tqdm(test_data["oneshot_prompts"]):
    out = gemma_lm.generate(prompt, max_length=max_length_pred_oneshot)
    oneshot_outs.append(out)
    
end_time_oneshot = time.time()
elapsed_time_oneshot = end_time_oneshot - start_time_oneshot
all_test_time.append(elapsed_time_oneshot)

print(f"Execution Time:\n{elapsed_time_oneshot} seconds, {round(elapsed_time_oneshot/60, 2)} min, {round(elapsed_time_oneshot/3600, 2)} h")

###  📖<span style='color: Orange;'><b>4.3 Three Shot Learning</b></span>

In [None]:
start_time_3shot = time.time()

threeshot_outs = []
for prompt in tqdm(test_data["prompt_3shot"]):
    out = gemma_lm.generate(prompt, max_length=max_length_pred_3shot)
    threeshot_outs.append(out)
    
end_time_3shot = time.time()
elapsed_time_3shot = end_time_3shot - start_time_3shot
all_test_time.append(elapsed_time_3shot)

print(f"Execution Time:\n{elapsed_time_3shot} seconds, {round(elapsed_time_3shot/60, 2)} min, {round(elapsed_time_3shot/3600, 2)} h")

###  📖<span style='color: Orange;'><b>4.4 Five Shot Learning</b></span>

In [None]:
start_time_5shot = time.time()

fiveshot_outs = []
for prompt in tqdm(test_data["prompt_5shot"]):
    out = gemma_lm.generate(prompt, max_length=max_length_pred_5shot)
    fiveshot_outs.append(out)
    
end_time_5shot = time.time()
elapsed_time_5shot = end_time_5shot - start_time_5shot
all_test_time.append(elapsed_time_5shot)

print(f"Execution Time:\n{elapsed_time_5shot} seconds, {round(elapsed_time_5shot/60, 2)} min, {round(elapsed_time_5shot/3600, 2)} h")

###  📖<span style='color: Orange;'><b>4.5 Seven Shot Learning</b></span>

In [None]:
start_time_7shot = time.time()

sevenshot_outs = []
for prompt in tqdm(test_data["prompt_7shot"]):
    out = gemma_lm.generate(prompt, max_length=max_length_pred_7shot)
    sevenshot_outs.append(out)
    
end_time_7shot = time.time()
elapsed_time_7shot = end_time_7shot - start_time_7shot
all_test_time.append(elapsed_time_7shot)

print(f"Execution Time:\n{elapsed_time_7shot} seconds, {round(elapsed_time_7shot/60, 2)} min, {round(elapsed_time_7shot/3600, 2)} h")

###  📖<span style='color: Orange;'><b>4.6 Nine Shot Learning</b></span>

In [None]:
start_time_9shot = time.time()

nineshot_outs = []
for prompt in tqdm(test_data["prompt_9shot"]):
    out = gemma_lm.generate(prompt, max_length=max_length_pred_9shot)
    nineshot_outs.append(out)
    
end_time_9shot = time.time()
elapsed_time_9shot = end_time_9shot - start_time_9shot
all_test_time.append(elapsed_time_9shot)

print(f"Execution Time:\n{elapsed_time_9shot} seconds, {round(elapsed_time_9shot/60, 2)} min, {round(elapsed_time_9shot/3600, 2)} h")

# 🌟<span style='color: Orange;'><b> 5. Post-processing of Predictions</b></span>

In this section, some of the function definitions are adapted from other Notebooks.

<span style='color: green;'><b>The referenced Notebooks are as follows:</b></span>
<blockquote>
    
[piidd-let-s-go-higher](https://www.kaggle.com/code/hyunsoolee1010/piidd-let-s-go-higher)
    
[rule-based-approach](https://www.kaggle.com/code/emiz6413/rule-based-approach)
</blockquote>
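
As a rough orientation for the functions defined later in this section, the post-processing goes from a JSON-like answer string to per-token BIO labels. The following simplified sketch illustrates that flow on a toy example. It rests on assumptions: `extract_entities` and `bio_labels` are hypothetical helpers, entities are split on whitespace instead of using the spaCy tokenizer, and matching is done by list slicing rather than by `find_span`.

```python
import re


def extract_entities(answer_text, labels):
    """Pull the list that follows each label out of a JSON-like answer string."""
    found = {}
    for label in labels:
        match = re.search(rf'"?{label}"?\s*:\s*\[(.*?)\]', answer_text)
        # Collect every double-quoted item inside the brackets
        found[label] = re.findall(r'"([^"]+)"', match.group(1)) if match else []
    return found


def bio_labels(tokens, entity_tokens, label):
    """Map token index -> B-/I- label for each occurrence of entity_tokens."""
    n = len(entity_tokens)
    tagged = {}
    for start in range(len(tokens) - n + 1):
        if n and tokens[start:start + n] == entity_tokens:
            for offset in range(n):
                prefix = "I" if offset else "B"
                tagged[start + offset] = f"{prefix}-{label}"
    return tagged


answer = '{"NAME_STUDENT": ["Ann Lee"], "EMAIL": []}'  # toy model answer
tokens = ["My", "name", "is", "Ann", "Lee", "."]        # toy document tokens
entities = extract_entities(answer, ["NAME_STUDENT", "EMAIL"])
print(entities)
print(bio_labels(tokens, entities["NAME_STUDENT"][0].split(), "NAME_STUDENT"))
```

The first token of each matched entity receives the `B-` prefix and the following tokens receive `I-`, which is the label format compared against the ground truth when computing the score.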

####  ⚠️<span style='color: Orange;'><b> You should run this block if you skipped the learning phase.</b></span>

In [None]:
# Load the saved outputs of the different learning paradigms.

# my_dicts = json.load(open("/kaggle/input/temp-test-gemma-output/my_dict4.json"))
# zeroshot_outs = my_dicts['zeroshot_outs']
# oneshot_outs = my_dicts['oneshot_outs']
# threeshot_outs = my_dicts['threeshot_outs']
# fiveshot_outs = my_dicts['fiveshot_outs']
# sevenshot_outs = my_dicts['sevenshot_outs']
# nineshot_outs = my_dicts['nineshot_outs']

# zeroshot_time = my_dicts["zeroshot_time"]
# oneshot_time = my_dicts["oneshot_time"]
# threeshot_time = my_dicts["threeshot_time"]
# fiveshot_time = my_dicts["fiveshot_time"]
# sevenshot_time = my_dicts["sevenshot_time"]
# nineshot_time = my_dicts["nineshot_time"]

# all_test_time = [zeroshot_time,oneshot_time,threeshot_time,fiveshot_time,sevenshot_time, nineshot_time]

In [None]:
# Tokenizer used to split predicted entities before calling find_span()
nlp = English()

# This function finds every position in the original token list (document)
# where the token sequence of an answer (target) occurs.
# ref: https://www.kaggle.com/code/emiz6413/rule-based-approach
def find_span(target, document):
    # target: the token list of a predicted entity (after tokenizer processing).
    # document: the token list of the full text.
    
    # idx is the position in target that the next document token must match
    idx = 0
    spans = []
    span = []

    for i, token in enumerate(document):
        if token != target[idx]:
            idx = 0
            span = []
            continue
        # print(f"token id:{i}, token str: {token} ")
        span.append(i)
        idx += 1
        if idx == len(target):
            if span not in spans:
                spans.append(span)
                span = []
                idx = 0
                continue
            else:
                span = []
                idx = 0
    return spans

def get_answers_from_prediction(pred, type_train):
    # This function is used to extract the answer part from the raw output of the model
    n = int(type_train.split("-")[0]) # get n for "n-shot"
    if n != 0: # not zero-shot
        split_pred = pred.split(f"Example{n}:\n\n")[1]
        all_res = split_pred.split("Answer:\n")[2].strip()
    else: # zero-shot
        all_res = pred.split("Answer:\n")[1].strip()

    # Keep only the first answer if the model went on to generate another "Text:" block
    all_res = all_res.split("\nText:")[0].strip()
    
    # Post-process the answer:
    # replace ' with ", and remove \n, \t and spaces.
    # Append ]} if it is missing, because we assume the answer is in JSON format and want it to be closed.
    text = all_res.replace("'", '"')
    text = text.replace("\n", '')
    text = text.replace("\t", '')
    text = text.replace(" ", '')
    if "]}" not in text:
        text += "]}"

    return text

def porcess_matched_list(match):
    # match is a list of strings; drop the empty ones
    matched = [m for m in match if m!=""]
    list_match = []
    for ele in matched:
        # '"email@com","email2@com"' -> '', 'email@com', ',', 'email2@com', ''
        ele_splitted = ele.split('"')
        # print("ele_splitted", ele_splitted)
        ele_splitted2 = []
        for element in ele_splitted:
            # 'email@com,email2@com' -> 'email@com', 'email2@com'
            ele_splitted2 += element.split(",")
        # print("ele_splitted2", ele_splitted2)
        
        # clean incorrect elements/text in the final answer
        ans = [element.strip() for element in ele_splitted2 if element not in [',', '', ' ',', ', 'VALUE', 'value', '{', '}']]
        list_match+=ans
    return list_match

# This function is used to extract the answer from the raw output of the model and convert to dict
def postprocess_output_to_dict(y_preds_raw, type_train):
    # Register answers of all examples in res_dicts
    res_dicts = []
    for i, pred in enumerate(y_preds_raw):
        text = get_answers_from_prediction(pred, type_train)
        # Register answers of one example in res_dict
        res_dict = {}
        for l in test_labels[:-1]:
            # Define as many rules as possible to match the answers generated by Gemma.
            # Raw f-strings avoid invalid escape sequence warnings for \[ and \]
            pattern = rf'"*{l}"*:\[(.*?)\]'   # "NAME_STUDENT":[...]
            pattern2 = rf'"*{l}"*,\[(.*?)\]'  # "NAME_STUDENT",[...]
            pattern3 = rf'"*{l}"*:"(.*?)"'    # "NAME_STUDENT":"..."
            match = re.findall(pattern, text)
            match2 = re.findall(pattern2, text)
            match3 = re.findall(pattern3, text)

            # Use the first rule that matched, in order of priority
            if match != []:
                res_dict[l] = porcess_matched_list(match)
            elif match2 != []:
                res_dict[l] = porcess_matched_list(match2)
            elif match3 != []:
                res_dict[l] = porcess_matched_list(match3)
            else:
                res_dict[l] = []
        res_dicts.append(res_dict)
    return res_dicts

# This function builds a list of dictionaries that carries the info of the original example
# (document id, token index, token string) together with the predicted BIO label,
# which is needed to compute the f5 score.
# ref: https://www.kaggle.com/code/hyunsoolee1010/piidd-let-s-go-higher
def generate_output_df(res_dicts, df_test):
    real_ypreds = []
    for i, rd in enumerate(res_dicts):
        # If the dict contains at least one predicted value
        if list(rd.values()) != [[] for i in range(len(test_labels[:-1]))]:
            _data = df_test.loc[i]
            for k in rd.keys():
                values = rd[k]
                for v in values:
                    if v not in [',', '', ' ',', ', 'VALUE', 'value', '{', '}']:
                        target = [t.text for t in nlp.tokenizer(v)] 
                        matched_spans = find_span(target, _data["tokens"])
                        # print("target: ", target)
                        # print("_data: ", _data["tokens"])
                        # print("matched_spans: ", matched_spans)
                        for matched_span in matched_spans:
                            for intermediate, token_idx in enumerate(matched_span):
                                prefix = "I" if intermediate else "B"
                                elem = {"document": _data["document"], 
                                       "token": token_idx, 
                                       "label": f"{prefix}-{k}", 
                                       "token_str": _data["tokens"][token_idx]}
                                if elem not in real_ypreds:
                                    real_ypreds.append(elem)
    return real_ypreds

# ref: https://www.kaggle.com/code/emiz6413/rule-based-approach
# I use the f5 score because it is the metric used by the competition.
def pii_fbeta_score(pred_df, gt_df, beta=5):
    """
    Parameters:
    - pred_df (DataFrame): DataFrame containing predicted PII labels.
    - gt_df (DataFrame): DataFrame containing ground truth PII labels.
    - beta (float): The beta parameter for the F-beta score, controlling the trade-off between precision and recall.

    Returns:
    - float: Micro F-beta score.
    """   
    df = pred_df.merge(gt_df,how='outer',on=['document',"token"],suffixes=('_pred','_gt'))

    df['cm'] = ""

    df.loc[df.label_gt.isna(),'cm'] = "FP"


    df.loc[df.label_pred.isna(),'cm'] = "FN"
    df.loc[(df.label_gt.notna()) & (df.label_gt!=df.label_pred),'cm'] = "FN"

    df.loc[(df.label_pred.notna()) & (df.label_gt.notna()) & (df.label_gt==df.label_pred),'cm'] = "TP"
    
    FP = (df['cm']=="FP").sum()
    FN = (df['cm']=="FN").sum()
    TP = (df['cm']=="TP").sum()

    s_micro = (1+(beta**2))*TP/(((1+(beta**2))*TP) + ((beta**2)*FN) + FP)

    return s_micro
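
The expression for `s_micro` is the standard micro F-beta written in terms of token counts. As a worked illustration with made-up counts (the values of `tp`, `fp` and `fn` below are hypothetical, and `micro_fbeta` is a helper defined only for this sketch), the following shows that it agrees with the precision/recall form and that, with beta=5, a missed entity costs far more than a false alarm.

```python
def micro_fbeta(tp, fp, fn, beta=5):
    """Micro F-beta from token-level counts (same expression as s_micro)."""
    b2 = beta ** 2
    return (1 + b2) * tp / ((1 + b2) * tp + b2 * fn + fp)


tp, fp, fn = 8, 2, 1  # hypothetical counts
score = micro_fbeta(tp, fp, fn)

# Equivalent form using precision and recall
precision = tp / (tp + fp)
recall = tp / (tp + fn)
from_pr = (1 + 25) * precision * recall / (25 * precision + recall)

print(round(score, 4), round(from_pr, 4))  # 0.8851 0.8851
print(round(micro_fbeta(8, 3, 1), 4))      # one more false positive: 0.8814
print(round(micro_fbeta(8, 2, 2), 4))      # one more false negative: 0.8
```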

In [None]:
# Get the ground truth.
# ref: https://www.kaggle.com/code/emiz6413/rule-based-approach
gt = []
for _, row in test_data.iterrows():
    for token_idx, (token, label) in enumerate(zip(row["tokens"], row["labels"])):
        if label == "O":
            continue
        gt.append(
            {"document": row["document"], "token": token_idx, "label": label}
        )
gt_df = pd.DataFrame(gt)
gt_df["row_id"] = gt_df.index
gt_df.head()

In [None]:
# Calculate the score.
f5s = []
types = ["0-shot","1-shot", "3-shot", "5-shot", "7-shot", "9-shot"]  # 
for i, pred in enumerate([zeroshot_outs, oneshot_outs, threeshot_outs, fiveshot_outs, sevenshot_outs, nineshot_outs]):
    res_dicts = postprocess_output_to_dict(pred, types[i])
    real_ypreds = generate_output_df(res_dicts, test_data)
    
    pred_df = pd.DataFrame(real_ypreds)
    # Sort by document first, then by token index within each document
    pred_df = pred_df.sort_values(by=['document', 'token'])
    pred_df = pred_df.reset_index(drop=True)
    pred_df["row_id"] = pred_df.index

    f5 = pii_fbeta_score(pred_df, gt_df, beta=5)
    print(f"F5 score of {types[i]}: {f5}")
    f5s.append(f5)

In [None]:
# Save the outputs.
my_dict = {'zeroshot_time': all_test_time[0], 
           'oneshot_time': all_test_time[1], 
           'threeshot_time': all_test_time[2],
           'fiveshot_time': all_test_time[3],
           'sevenshot_time': all_test_time[4],
           'nineshot_time': all_test_time[5],

           'zeroshot_f5': f5s[0],
           'oneshot_f5': f5s[1],
           'threeshot_f5': f5s[2],
           'fiveshot_f5': f5s[3],
           'sevenshot_f5': f5s[4],
           'nineshot_f5': f5s[5],
           
           "zeroshot_outs":zeroshot_outs,
           "oneshot_outs":oneshot_outs,
           "threeshot_outs":threeshot_outs,
           "fiveshot_outs":fiveshot_outs,
           "sevenshot_outs":sevenshot_outs,
           "nineshot_outs":nineshot_outs
           }

# Save dict
with open('/kaggle/working/my_dict.json', 'w') as f:
    json.dump(my_dict, f)

In [None]:
# Plot all f5 scores and learning time
import plotly.graph_objects as go
from plotly.subplots import make_subplots

font_family = "Arial"
fig = make_subplots(rows=2, cols=1) #, subplot_titles=("TrainSet", "TestSet")

# Define data to plot
x1 = types
y1 = f5s

x2 = types
y2 = all_test_time

# Add the F5-score line chart to the first row and the learning-time line chart to the second row
fig.add_trace(go.Scatter(x=x1, y=y1, mode='lines', name='F5-score'), 1, 1)
fig.add_trace(go.Scatter(x=x2, y=y2, mode='lines', name='Learning-Time'), 2, 1)

# Update Layout Settings
fig.update_layout(width=800, height=800,paper_bgcolor="WHITE")

fig.update_layout(title_text='Performance on N-Shot Learning Paradigms',
                  font=dict(family=font_family,
                            size=10,
                            color="#75767A"),
                 )
fig.update_yaxes(title_text="F5 Score", row=1, col=1)
fig.update_yaxes(title_text="Learning Time (s)", row=2, col=1)

fig.update_layout(title_font_family=font_family,
                  title_font_color='BLACK',
                  title_font_size=15,
                  title_x=0.5)

fig.update_layout(legend_title="",
                  showlegend=True,
                  legend=dict(x=0.33, y=1.06, orientation="h")
                 )

fig.show()