<div style="padding: 0.5em; background-color: #1876d1; color: #fff; font-weight: bold; font-size: 1.4em;">
    [Approach 3]  Location Mention Recognition - Fine-tuning an LLM
</div>

In this Jupyter notebook, we use an LLM to extract location mentions from tweets posted on X (formerly Twitter) during emergency situations.

Steps:
* Retrieve the dataset
* Prepare the prompts for fine-tuning
* Tested model: <span style="color: red;">01 - Mistral-7B-Instruct-v0.2</span>

---
<b>#Microsoft Learn Challenge, #Zindi, #Hamad Bin Khalifa University </b>

### **Importing Libraries**

In [1]:
# !pip install trl

In [1]:
!python --version

Python 3.11.8




In [2]:
# general utils
import werpy
import numpy as np
import pandas as pd
import seaborn as sns
import os, sys, requests
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
pd.set_option('display.max_colwidth', 300)

# utils setup
current_directory = os.getcwd()
root_directory = os.path.abspath(os.path.join(current_directory, os.pardir))
sys.path.append(root_directory)

# logging
import wandb
os.environ["WANDB_NOTEBOOK_NAME"] = "fine-tunne--mistral-7B.ipynb"

# custom utils
from utils.io import LMR_XML_Scrapper
from utils.preprocessing import Preprocess

# hugging face utils
import torch
from datasets import Dataset
from dataclasses import dataclass
from trl import SFTConfig, SFTTrainer
from transformers.utils import PaddingStrategy
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

### **Preparing Data**

We have to download the dataset and format it in such a way that it can be used by the model.

In [2]:
LMR_XML_Scrapper(output_dir="../data/self_scrapped/raw-llm").run()

Processing dataset: california_wildfires_2018


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.17file/s]


Processing dataset: canada_wildfires_2016


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.37file/s]


Processing dataset: cyclone_idai_2019


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.14file/s]


Processing dataset: ecuador_earthquake_2016


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.43file/s]


Processing dataset: greece_wildfires_2018


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.51file/s]


Processing dataset: hurricane_dorian_2019


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.36file/s]


Processing dataset: hurricane_florence_2018


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.24file/s]


Processing dataset: hurricane_harvey_2017


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.29file/s]


Processing dataset: hurricane_irma_2017


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.43file/s]


Processing dataset: hurricane_maria_2017


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.34file/s]


Processing dataset: hurricane_matthew_2016


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.53file/s]


Processing dataset: italy_earthquake_aug_2016


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.60file/s]


Processing dataset: kaikoura_earthquake_2016


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.46file/s]


Processing dataset: kerala_floods_2018


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.15file/s]


Processing dataset: maryland_floods_2018


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.74file/s]


Processing dataset: midwestern_us_floods_2019


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.20file/s]


Processing dataset: pakistan_earthquake_2019


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.51file/s]


Processing dataset: puebla_mexico_earthquake_2017


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.33file/s]


Processing dataset: srilanka_floods_2017


Extracting Files : 100%|██████████| 3/3 [00:01<00:00,  2.81file/s]

Processing complete.





- Let's concatenate the per-event files into one train, one dev, and one test dataset

In [17]:
train_dfs = []
dev_dfs   = []
test_dfs  = []
path_dfs  = "../data/self_scrapped/raw-llm"
for filename in os.listdir(path_dfs):
    if filename.endswith(".csv"):
        file_path = os.path.join(path_dfs, filename)
        if filename.startswith("train"):
            df = pd.read_csv(file_path)
            train_dfs.append(df)
        elif filename.startswith("dev"):
            df = pd.read_csv(file_path)
            dev_dfs.append(df)
        elif filename.startswith("test_unlabeled"):
            df = pd.read_csv(file_path)
            test_dfs.append(df)

df_train = pd.concat(train_dfs, ignore_index=True) if train_dfs else pd.DataFrame()
df_test  = pd.concat(test_dfs, ignore_index=True) if test_dfs else pd.DataFrame()
df_dev   = pd.concat(dev_dfs, ignore_index=True) if dev_dfs else pd.DataFrame()

df_train.to_csv("../data/transformed/lmr_train.csv")
#df_test.to_csv("../data/transformed/lmr_test.csv")
df_dev.to_csv("../data/transformed/lmr_dev.csv")

print("TRAIN SHAPE: ", df_train.shape)
print("TEST  SHAPE: ", df_test.shape)
print("DEV   SHAPE: ", df_dev.shape)

TRAIN SHAPE:  (14392, 3)
TEST  SHAPE:  (4066, 3)
DEV   SHAPE:  (2056, 3)


- The tweets contain hashtags, non-ASCII characters, stopwords, and other noise, so the data has to be cleaned

In [4]:
df_train.head(5)

Unnamed: 0,tweet_id,plain_text,xml_text
0,ID_1022420413882744832,Nearly half of #houses checked in #fire-stricken areas deemed #uninhabitable #GO #PrayForGreece #PrayForAthens #AthensFires ἞C἟7,Nearly half of #houses checked in #fire-stricken areas deemed #uninhabitable #GO #PrayForGreece #PrayForAthens #AthensFires ἞C἟7
1,ID_1021778661895294976,RT @anadoluagency: #Greece: Death toll from wildfires hits 74,RT @anadoluagency: #<COUNTRY>Greece</COUNTRY>: Death toll from wildfires hits 74
2,ID_1022015997740503042,When the essence of cooperation meets the sad reality of lifeThe IPA partner country offers financial aid to Greece to handle disaster @InterregIPACBC #Greecefires,When the essence of cooperation meets the sad reality of lifeThe IPA partner country offers financial aid to <COUNTRY>Greece</COUNTRY> to handle disaster @InterregIPACBC #Greecefires
3,ID_1022557424585240576,We are live from the Lureio Idrima the orphanage and nursing home operared by the nuns of the Holy Trinity Monastery that was destroyed by the fire in Neos Voutzas. Here too the scene is apocalyptic.,We are live from the Lureio Idrima the orphanage and nursing home operared by the nuns of the <HUMAN-MADE-POINT-OF-INTEREST>Holy Trinity Monastery</HUMAN-MADE-POINT-OF-INTEREST> that was destroyed by the fire in <NEIGHBORHOOD>Neos Voutzas</NEIGHBORHOOD>. Here too the scene is apocalyptic.
4,ID_1021749412639457280,RT @AP: Greek prime minister declares 3-day national mourning period for dozens killed by wildfires near Athens.,RT @AP: Greek prime minister declares 3-day national mourning period for dozens killed by wildfires near <CITY>Athens</CITY>.


In [5]:
df_train = Preprocess.remove_non_ascii(df_train, column_name='plain_text')
df_train = Preprocess.remove_non_ascii(df_train, column_name='xml_text')
df_dev   = Preprocess.remove_non_ascii(df_dev, column_name='plain_text')
df_dev   = Preprocess.remove_non_ascii(df_dev, column_name='xml_text')

In [6]:
df_train.head(5)

Unnamed: 0,tweet_id,plain_text,xml_text
0,ID_1022420413882744832,Nearly half of #houses checked in #fire-stricken areas deemed #uninhabitable #GO #PrayForGreece #PrayForAthens #AthensFires C7,Nearly half of #houses checked in #fire-stricken areas deemed #uninhabitable #GO #PrayForGreece #PrayForAthens #AthensFires C7
1,ID_1021778661895294976,RT @anadoluagency: #Greece: Death toll from wildfires hits 74,RT @anadoluagency: #<COUNTRY>Greece</COUNTRY>: Death toll from wildfires hits 74
2,ID_1022015997740503042,When the essence of cooperation meets the sad reality of lifeThe IPA partner country offers financial aid to Greece to handle disaster @InterregIPACBC #Greecefires,When the essence of cooperation meets the sad reality of lifeThe IPA partner country offers financial aid to <COUNTRY>Greece</COUNTRY> to handle disaster @InterregIPACBC #Greecefires
3,ID_1022557424585240576,We are live from the Lureio Idrima the orphanage and nursing home operared by the nuns of the Holy Trinity Monastery that was destroyed by the fire in Neos Voutzas. Here too the scene is apocalyptic.,We are live from the Lureio Idrima the orphanage and nursing home operared by the nuns of the <HUMAN-MADE-POINT-OF-INTEREST>Holy Trinity Monastery</HUMAN-MADE-POINT-OF-INTEREST> that was destroyed by the fire in <NEIGHBORHOOD>Neos Voutzas</NEIGHBORHOOD>. Here too the scene is apocalyptic.
4,ID_1021749412639457280,RT @AP: Greek prime minister declares 3-day national mourning period for dozens killed by wildfires near Athens.,RT @AP: Greek prime minister declares 3-day national mourning period for dozens killed by wildfires near <CITY>Athens</CITY>.
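
`Preprocess.remove_non_ascii` comes from this project's custom `utils.preprocessing` module, whose source is not shown in this notebook. As a rough sketch of what such a step might do, assuming it simply drops every character outside the ASCII range (the helper name `strip_non_ascii` is hypothetical and not part of the project):

```python
def strip_non_ascii(text: str) -> str:
    """Drop every character that cannot be encoded as ASCII."""
    # errors="ignore" silently discards the characters that fail to encode
    return text.encode("ascii", errors="ignore").decode("ascii")


# Two non-ASCII code points surround the "C", as in the first row of the preview
sample = "#PrayForAthens #AthensFires \u1f1eC\u1f1f7"
print(strip_non_ascii(sample))  # -> #PrayForAthens #AthensFires C7
```

Under that assumption, the same function could be applied to a DataFrame column with `df[column_name].map(strip_non_ascii)`. The leftover `C7` in the first row shows the cost of this approach: fragments of a mis-encoded emoji can survive as stray ASCII characters.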


- Take a look at the extracted tags and prepare the list of entity types

In [7]:
tags_list = pd.read_csv("../utils/tag_description.csv")
tags_list

Unnamed: 0,original_type,xml_tag
0,State,STATE
1,County,COUNTY
2,City/town,CITY
3,Road/street,ROAD
4,Island,ISLAND
5,Human-made Point-of-Interest,HUMAN-MADE-POINT-OF-INTEREST
6,Continent,CONTINENT
7,Neighborhood,NEIGHBORHOOD
8,Natural Point-of-Interest,NATURAL-POINT-OF-INTEREST
9,District,DISTRICT


In [8]:
tags_description = [
    "CONTINENT: A large continuous landmass on the Earth's surface",
    "COUNTRY: A nation or territory recognized as an independent state",
    "COUNTY: A geographical region within a country, often a subdivision of a state or province.",
    "DISTRICT: An administrative division within a city, county, or country.",
    "STATE: A larger administrative division within a country",
    "CITY: A large and significant urban area",
    "ROAD: A street, avenue, or highway that connects different locations.",
    "ISLAND: A landmass completely surrounded by water.",
    "NEIGHBORHOOD: A localized community within a city or town.",
    "HUMAN-MADE-POINT-OF-INTEREST: A landmark created by humans, such as monuments or buildings.",
    "NATURAL-POINT-OF-INTEREST: A location of natural significance, like mountains or rivers.",
]
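
The tag names listed above are the ones that wrap location mentions in the `xml_text` column. As an illustration of that label format, here is a minimal sketch that parses a tagged string back into `(tag, mention)` pairs and recovers the untagged text. The helpers `extract_mentions` and `strip_tags` are hypothetical, and the regular expression assumes tags are upper-case, possibly hyphenated, and never nested:

```python
import re
from typing import List, Tuple

# Matches <TAG>mention</TAG>, where TAG is upper-case letters and hyphens
TAG_PATTERN = re.compile(r"<([A-Z][A-Z-]*)>(.*?)</\1>")


def extract_mentions(xml_text: str) -> List[Tuple[str, str]]:
    """Return (tag, mention) pairs in order of appearance."""
    return [(m.group(1), m.group(2)) for m in TAG_PATTERN.finditer(xml_text)]


def strip_tags(xml_text: str) -> str:
    """Remove the XML tags and keep the enclosed mention text."""
    return TAG_PATTERN.sub(r"\2", xml_text)


tagged = (
    "nuns of the <HUMAN-MADE-POINT-OF-INTEREST>Holy Trinity Monastery"
    "</HUMAN-MADE-POINT-OF-INTEREST> destroyed by the fire in "
    "<NEIGHBORHOOD>Neos Voutzas</NEIGHBORHOOD>."
)
print(extract_mentions(tagged))
print(strip_tags(tagged))
```

A check such as `strip_tags(row.xml_text) == row.plain_text` could be used to verify that a label, or a model prediction, leaves the rest of the tweet unchanged.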

### **Init model utils**

Initialize the model and the tokenizer.

In [9]:
model     = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### **Including Chain-Of-Thought in the prompt design**



In [10]:
messages = [
    {"role": "user", "content": "usr_msg1"},
    {"role": "assistant", "content": "asst_msg1"},
    {"role": "user", "content": "usr_msg2"},
    {"role": "assistant", "content": "asst_msg2"},
]
tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

'<s> [INST] usr_msg1 [/INST] asst_msg1</s> [INST] usr_msg2 [/INST] asst_msg2</s>'
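
To make the structure of that string explicit, the sketch below rebuilds it in plain Python. It only mimics the output shown above for alternating user and assistant turns; it is not the tokenizer's actual template, and the function name `render_like_template` is hypothetical:

```python
from typing import Dict, List


def render_like_template(messages: List[Dict[str, str]]) -> str:
    """Mimic the rendered chat string shown above (illustration only)."""
    parts = ["<s>"]  # beginning-of-sequence marker, emitted once
    for message in messages:
        if message["role"] == "user":
            # user turns are wrapped in instruction markers
            parts.append(f" [INST] {message['content']} [/INST]")
        else:
            # assistant turns are closed with the end-of-sequence marker
            parts.append(f" {message['content']}</s>")
    return "".join(parts)


demo_messages = [
    {"role": "user", "content": "usr_msg1"},
    {"role": "assistant", "content": "asst_msg1"},
    {"role": "user", "content": "usr_msg2"},
    {"role": "assistant", "content": "asst_msg2"},
]
print(render_like_template(demo_messages))
```

The point to retain is that each assistant answer follows a `[/INST]` marker and ends with `</s>`, which is what later allows the last answer to be located inside a formatted training example.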

In [11]:
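def formating_prompt(rule_set: List[str], input_str: str, label_str: str, tokenizer: PreTrainedTokenizerBase) -> List[Dict[str, str]]:
    """Build the few-shot chat conversation (a list of role/content messages) for one tweet."""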
    rule_str = "\n".join(rule_set)

    # Message 1
    usr_msg1 = "You are given a user utterance that may contain location mentions, and you have to perform the Location Mention Recognition " \
        "(LMR) task. Given microblogging posts written during emergencies, the task is to recognize all location mentions. " \
        "You are also given a list of entity types representing the location types we want to recognize. " \
        "Your task is to detect and identify all instances of the supplied Location entity types in the user utterance. " \
        "The output must have the same content as the input. Only the tokens that match the Location entities in the " \
        "list should be enclosed within XML tags. The XML tag comes from the Location entities described in the list below. " \
        "For example, a city should be enclosed within <CITY></CITY> tags. " \
        "Ensure that all entities are identified. Do not perform false identifications." \
        f"""\n\nList Of Entities\n{rule_str}"""\
        "\n\n" \
        "Are the instructions clear to you?"
    asst_msg1 = "Yes, the instructions are clear. I will identify and enclose within the corresponding XML tags, " \
        "all instances of the specified Location entity types in the user utterance. For example, " \
        "<CITY><Name of city></CITY>, <ROAD><A street, avenue, or highway that connects different locations.></ROAD>, etc. " \
        "leaving the rest of the user utterance unchanged."
    
    # Message 2
    usr_msg2 = "Florida Bahamas flooding shuts down train lines."
    asst_msg2 = "<STATE>Florida</STATE> <ISLAND>Bahamas</ISLAND> flooding shuts down train lines."
    
    # Message 3
    usr_msg3 = "Give a brief explanation of why your answer is correct."
    asst_msg3 = "I identified and enclosed within corresponding XML tags, all instances of the specified Location " \
        "entity types in the user utterance - the U.S. state \"Florida\" within the <STATE></STATE> tag, and " \
        "the island group \"Bahamas\" within the <ISLAND></ISLAND> tag. The rest of the user " \
        "utterance was left unchanged as it did not contain any other identified Location entities."
    
    # Message 4
    usr_msg4 = "Great! I am now going to give you another user utterance. Please detect Location entities in it " \
        "according to the previous instructions. Do not include an explanation in your answer."
    asst_msg4 = "Sure! Please give me the user utterance."

    # Encode in conversation
    messages = [
        {"role": "user", "content": usr_msg1},
        {"role": "assistant", "content": asst_msg1},
        {"role": "user", "content": usr_msg2},
        {"role": "assistant", "content": asst_msg2},
        {"role": "user", "content": usr_msg3},
        {"role": "assistant", "content": asst_msg3},
        {"role": "user", "content": usr_msg4},
        {"role": "assistant", "content": asst_msg4},
        {"role": "user", "content": input_str},
        {"role": "assistant", "content": label_str},
    ]

    return messages

In [12]:
def create_prompt_df(df: pd.DataFrame, rule_set: List[str], tokenizer: PreTrainedTokenizerBase) -> Dataset:
    """Build a Hugging Face Dataset with one chat-formatted prompt per row of df."""
    prompt_list = []
    for _, row in df.iterrows():
        input_str = row['plain_text']
        label_str = row['xml_text']
        prompt_dict = formating_prompt(rule_set, input_str, label_str, tokenizer)
        prompt_list.append(prompt_dict)
    
    dataset = Dataset.from_dict({"prompt": prompt_list})
    dataset = dataset.map(lambda x: {"formatted_prompt": tokenizer.apply_chat_template(x["prompt"], tokenize=False, add_generation_prompt=False)})
    return dataset

### **Preparing model for fine-tuning**

In [13]:
dataset = {} 
dataset['train'] = create_prompt_df(df_train, tags_description, tokenizer)
dataset['eval']  = create_prompt_df(df_dev, tags_description, tokenizer)

Map:   0%|          | 0/14392 [00:00<?, ? examples/s]

Map:   0%|          | 0/2056 [00:00<?, ? examples/s]

- Data Collator

In [14]:
@dataclass
class CustomDataCollatorWithPadding:
    """
    Data collator that will dynamically pad the inputs received.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.

            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.0 (Volta).
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        labels = batch["input_ids"].clone()
        
        # Set loss mask for all pad tokens
        labels[labels == self.tokenizer.pad_token_id] = -100
        
        # Compute loss mask for appropriate tokens only
        for i in range(batch['input_ids'].shape[0]):
            
            # Decode the training input
            text_content = self.tokenizer.decode(batch['input_ids'][i][1:])  # slicing from [1:] is important because tokenizer adds bos token
            
            # Extract substrings for prompt text in the training input
            # The training input ends at the last user msg ending in [/INST]
            prompt_gen_boundary = text_content.rfind("[/INST]") + len("[/INST]")
            prompt_text = text_content[:prompt_gen_boundary]
            
            # print(f"""PROMPT TEXT:\n{prompt_text}""")
            
            # retokenize the prompt text only
            prompt_text_tokenized = self.tokenizer(
                prompt_text,
                return_overflowing_tokens=False,
                return_length=False,
            )
            # compute index where prompt text ends in the training input
            prompt_tok_idx = len(prompt_text_tokenized['input_ids'])
            
            # Set loss mask for all tokens in prompt text
            labels[i, :prompt_tok_idx] = -100
            
            # print("================DEBUGGING INFORMATION===============")
            # for idx, tok in enumerate(labels[i]):
            #     token_id = batch['input_ids'][i][idx]
            #     decoded_token_id = self.tokenizer.decode(batch['input_ids'][i][idx])
            #     print(f"""TOKID: {token_id} | LABEL: {tok} || DECODED: {decoded_token_id}""")
                    
        batch["labels"] = labels
        return batch
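
The masking idea in the collator can be shown without any model or tokenizer. The sketch below works on plain Python lists with made-up token ids; `mask_labels` is a hypothetical helper, and `-100` is used because it is the default `ignore_index` of PyTorch's cross-entropy loss, so positions carrying that label do not contribute to the loss:

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_labels(input_ids: List[int], prompt_len: int, pad_token_id: int) -> List[int]:
    """Copy input_ids into labels, hiding the prompt and the padding from the loss."""
    labels = list(input_ids)
    for idx, token_id in enumerate(labels):
        if idx < prompt_len or token_id == pad_token_id:
            labels[idx] = IGNORE_INDEX
    return labels


# Toy example: 4 prompt tokens, 3 answer tokens, 2 padding tokens (id 0)
toy_input_ids = [1, 10, 11, 12, 20, 21, 2, 0, 0]
print(mask_labels(toy_input_ids, prompt_len=4, pad_token_id=0))
# -> [-100, -100, -100, -100, 20, 21, 2, -100, -100]
```

Only the three answer tokens keep their ids, so the model is trained to produce the tagged tweet and not to reproduce the instructions. One caveat of masking by pad id: if the pad token is set to the same id as the end-of-sequence token, the closing end-of-sequence token of the answer is masked as well.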

- Model Building

In [16]:
pd.concat([df_train, df_dev])['plain_text'].apply(len).max()


945

In [15]:
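# Args
tokenizer.padding_side = 'right'
# max_seq_length is counted in tokens, so derive it from the tokenized prompts:
# the few-shot conversation is much longer than the raw tweet's character count
max_seq_length = max(
    len(tokenizer(text)["input_ids"])
    for split in dataset.values()
    for text in split["formatted_prompt"]
)
training_arguments = SFTConfig(
    dataset_text_field="formatted_prompt",
    output_dir="Mistral-7B-Output",
    max_seq_length=max_seq_length,
    packing=False,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
)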

# Trainer
trainer = SFTTrainer(
    model=model.to(torch.device('cpu')),
    train_dataset=dataset['train'],
    eval_dataset=dataset['eval'],
    
    tokenizer=tokenizer,
    args=training_arguments,
    data_collator=CustomDataCollatorWithPadding(
        tokenizer=tokenizer, 
        padding="longest", 
        max_length=max_seq_length, 
        return_tensors="pt"
    )
)

Map:   0%|          | 0/14392 [00:00<?, ? examples/s]

Map:   0%|          | 0/2056 [00:00<?, ? examples/s]

RuntimeError: MPS backend out of memory (MPS allocated: 18.01 GB, other allocations: 384.00 KB, max allowed: 18.13 GB). Tried to allocate 224.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).
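
A back-of-the-envelope estimate suggests why this allocation fails. The sketch below assumes a round figure of 7 billion parameters, 32-bit weights (4 bytes each, which is what loading without an explicit dtype has typically produced), and an Adam-style optimizer that keeps two extra buffers per parameter. All three are assumptions, not values measured in this notebook:

```python
def gigabytes(num_params: float, bytes_per_param: int) -> float:
    """Memory needed to store num_params values of the given width, in GB."""
    return num_params * bytes_per_param / 1e9


NUM_PARAMS = 7e9       # assumed round figure for a 7B model
MPS_LIMIT_GB = 18.13   # "max allowed" value reported in the error above

# Weights alone, stored as 32-bit floats
weights_fp32 = gigabytes(NUM_PARAMS, 4)
# Full fine-tuning: weights + gradients + two optimizer buffers, all 32-bit
training_fp32 = gigabytes(NUM_PARAMS, 4 + 4 + 8)

print(f"fp32 weights:            {weights_fp32:.1f} GB")
print(f"fp32 full fine-tuning:   {training_fp32:.1f} GB")
print(f"weights fit under limit: {weights_fp32 < MPS_LIMIT_GB}")
```

Under these assumptions the weights alone need about 28 GB, which already exceeds the 18.13 GB the MPS backend allows, before any gradients, optimizer state, or activations are counted.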

#### **Model Training**

In [None]:
trainer.train()

#### **Make Inference**