<div style="padding: 0.5em; background-color: #1876d1; color: #fff; font-weight: bold; font-size: 1.4em;">
    Fine-Tuning BERT for Location Mention Recognition
</div>

This notebook demonstrates the process of fine-tuning a BERT model to recognize and categorize location mentions in text using the IDRISI dataset. The task at hand is a type of Named Entity Recognition (NER) where the goal is to identify and classify location names, such as countries, cities, or landmarks, within a given text.

We use the BILOU (Begin, Inside, Last, Outside, Unit) labeling scheme, which marks entity boundaries explicitly: a multi-token mention is tagged Begin, Inside, and Last, a single-token mention is tagged Unit, and every other token is tagged Outside. Fine-tuning BERT on these labels lets the model apply its contextual representations to token classification, which is what detecting location mentions in varied text requires.
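
To make the scheme concrete, here is a minimal sketch of how BILOU tags could be derived from entity spans. The function name `bilou_tags` and the `(start, end, type)` span format are assumptions made for illustration; the notebook itself relies on `Preprocess.build_bilou_encoding` for this step.

```python
def bilou_tags(tokens, spans):
    """Return one BILOU tag per token.

    spans is a list of (start, end_exclusive, entity_type) over token indices
    (a hypothetical input format chosen for this sketch).
    """
    tags = ["O"] * len(tokens)
    for start, end, ent_type in spans:
        if end - start == 1:
            # Single-token mention
            tags[start] = f"U-{ent_type}"
        else:
            # Multi-token mention: Begin, zero or more Inside, then Last
            tags[start] = f"B-{ent_type}"
            for i in range(start + 1, end - 1):
                tags[i] = f"I-{ent_type}"
            tags[end - 1] = f"L-{ent_type}"
    return tags


tokens = "catastrophic flooding slam ellicott city maryland".split()
spans = [(3, 5, "CITY"), (5, 6, "STATE")]
for token, tag in zip(tokens, bilou_tags(tokens, spans)):
    print(f"{token}\t{tag}")
```

The two-token mention "ellicott city" receives `B-CITY L-CITY` and the single token "maryland" receives `U-STATE`, which matches the tag pattern visible in the data previews later in the notebook.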

The notebook is structured as follows:
1. **Setup and Installation**: Install and import the necessary libraries.
2. **Data Ingestion and Preprocessing**: Load the IDRISI dataset and prepare it for modeling, including tokenization and label mapping.
3. **Modeling Preparation**: Create custom datasets, define label mappings, and set up the BERT model for token classification.
4. **Fine-Tuning**: Train the BERT model on the labeled data, optimizing for accuracy in location mention recognition.
5. **Evaluation**: Assess the performance of the fine-tuned model using the Word Error Rate (WER) metric.

### **Setup & Utils**

In [48]:
#!pip install transformers jiwer werpy pandas scikit-learn tqdm accelerate -U

In [49]:
# basic
import sys, os, re, torch, werpy, jiwer
from collections import Counter
import pandas as pd
from tqdm import tqdm
import numpy as np

# ml commons
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForTokenClassification, Trainer, TrainingArguments, DataCollatorForTokenClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# device setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# utils setup
current_directory = os.getcwd()
root_directory = os.path.abspath(os.path.join(current_directory, os.pardir))
sys.path.append(root_directory)

# custom utils
from utils.io import Predictions
from utils.metrics import LMR_Metrics
from utils.io import LMR_BILOU_Scrapper, LMR_JSON_Scrapper
from utils.preprocessing import Preprocess
from utils.stratify import MultiLabelNERStratify

### **Helpers**

In [50]:
# def ingest_idrisi_data(bilou_base_dir='/kaggle/input/idrisi-location-mention/LMR/data/EN/gold-random-bilou/'):
#     sentences, labels = [], []
#     for root, dirs, files in os.walk(bilou_base_dir):
#         for file in files:
#             if file.endswith('.txt'):
#                 file_path = os.path.join(root, file)
#                 with open(file_path, 'r') as f:
#                     current_sentence, current_labels = [], []
#                     for line in f:
#                         word_label = line.strip().split()
#                         if len(word_label) == 2:
#                             word, label = word_label
#                             current_sentence.append(word)
#                             current_labels.append(label)
#                         elif len(current_sentence) > 0:
#                             sentences.append(' '.join(current_sentence))
#                             labels.append(','.join(current_labels))
#                             current_sentence, current_labels = [], []
#                     if len(current_sentence) > 0:
#                         sentences.append(' '.join(current_sentence))
#                         labels.append(','.join(current_labels))
#     return pd.DataFrame({'sentence': sentences, 'word_labels': labels})

def infer_on_sentences(sentences, model, tokenizer, max_len=300, with_extra=False):
    # Put the model in evaluation mode
    model.eval()
    
    results = []
    extra_results = []
    
    for sentence in tqdm(sentences):
        # Tokenize the sentence and prepare input for the model
        tokenized_sentence = tokenizer(
            sentence.split(),
            is_split_into_words=True,
            return_offsets_mapping=False,
            padding='max_length',
            truncation=True,
            max_length=max_len,
            return_tensors="pt"
        )
        
        # Move tensors to the correct device
        input_ids = tokenized_sentence['input_ids'].to(device)
        attention_mask = tokenized_sentence['attention_mask'].to(device)
        
        # Get predictions
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=2)  # Get the index of the highest logit for each token
        
        # Convert predictions to labels
        pred_labels = [id2label[pred.item()] for pred in predictions[0]]
        
        # Get the original tokens from input_ids
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
        
        # Filter out tokens with the 'O' label and concatenate them
        filtered_tokens = [
            token for token, label in zip(tokens, pred_labels)
            if label != 'O' and token not in ['[CLS]', '[SEP]', '[PAD]']
        ]
        filtered_labels = [
            label for token, label in zip(tokens, pred_labels)
            if label != 'O' and token not in ['[CLS]', '[SEP]', '[PAD]']
        ]
        
        results.append(" ".join(filtered_tokens))
        extra_results.append(filtered_labels)

    if with_extra:
        return results, extra_results

    return results

def calculate_performance_metric(df, col1='location', col2='prediction'):

    # Function to calculate WER for each row
    def calculate_wer(row):
        return jiwer.wer(str(row[col1]), str(row[col2]))

    # Calculate WER for each row
    df['WER'] = df.apply(calculate_wer, axis=1)

    # Calculate the average WER
    average_wer = df['WER'].mean()

    return df, average_wer

def clean_text(text):
    # Define a dictionary of replacements
    replacements = {
        ",": " ",
        "@": "",
        ".": "",
        ";": "",
        "-": " ",
        "_": "",
        "#": "",
        "##": ""
    }
    
    cleaned_text = text
    for k, v in replacements.items():
        cleaned_text = cleaned_text.replace(k, v)

    return cleaned_text

def clean_prediction(row, raw_prediction_col='prediction_raw'):
    prediction = row[raw_prediction_col]
    prediction = prediction.replace(" ##", "")
    if prediction.startswith("##"):
        prediction = " ".join(prediction.split()[1:])

    cleaned_text = clean_text(row['text'])
    lower_upper_map = {k.lower(): k for k in cleaned_text.split()}

    for k, v in lower_upper_map.items():
        prediction = prediction.replace(k, v)

    replacements = {
        "U S .": "",
        "L . A .": "L.A.",
        "P R . P R .": "P.R.",
        "N C . N C": "N.C.",
        "u . s .": "U.S.",
        "s . c .": "S.C.",
        "n . c . n . c": "N.C.",
        "n . c .": "N.C.",
        "d . c .": "D.C.",
        "n c . n c": "N.C.",
        '. r . p . r .': "P.R.",
        "u s .": "U.S.",

        " sc": "",
        " St": "",
        " -": "",
        " .": "",
        " _": "",
    }
    cleaned_prediction = prediction
    for k, v in replacements.items():
        cleaned_prediction = cleaned_prediction.replace(k, v)

    prediction_words = cleaned_prediction.split()
    if len(prediction_words) > 5:
        cleaned_prediction = Counter(cleaned_prediction.split()).most_common(1)[0][0]

    if len(set(prediction_words)) == 1:
        cleaned_prediction = prediction_words[0] 

    return cleaned_prediction

In [51]:
# Helper 1
def tokenize_and_preserve_labels(sentence, text_labels, tokenizer):
    tokenized_sentence, labels = [], []
    for word, label in zip(sentence.split(), text_labels.split(" ")):
        tokenized_word = tokenizer.tokenize(word)
        n_subwords = len(tokenized_word)
        tokenized_sentence.extend(tokenized_word)
        labels.extend([label] * n_subwords)
    return tokenized_sentence, labels

# Helper 2
class CustomDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __getitem__(self, index):
        sentence = self.data.words[index]  
        word_labels = self.data.labels[index]  
        tokenized_sentence, labels = tokenize_and_preserve_labels(sentence, word_labels, self.tokenizer)
        
        tokenized_sentence = ["[CLS]"] + tokenized_sentence + ["[SEP]"]
        # Special tokens get the "O" label; append (not insert(-1)) so the label for [SEP] lands at the end
        labels.insert(0, "O")
        labels.append("O")

        if len(tokenized_sentence) > self.max_len:
            tokenized_sentence = tokenized_sentence[:self.max_len]
            labels = labels[:self.max_len]
        else:
            tokenized_sentence += ['[PAD]'] * (self.max_len - len(tokenized_sentence))
            labels += ["O"] * (self.max_len - len(labels))

        attn_mask = [1 if tok != '[PAD]' else 0 for tok in tokenized_sentence]
        ids = self.tokenizer.convert_tokens_to_ids(tokenized_sentence)
        label_ids = [label2id[label] for label in labels]
        
        return {
            'input_ids': torch.tensor(ids, dtype=torch.long),
            'attention_mask': torch.tensor(attn_mask, dtype=torch.long),
            'labels': torch.tensor(label_ids, dtype=torch.long)
        }
    
    def __len__(self):
        return len(self.data)
    
# Helper 3
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    # For WER Metrics
    true_loc, pred_loc = [], []
    for i in range(len(labels)):
        label_decoded = [id2label[idx] for idx in labels[i]]
        pred_decoded  = [id2label[idx] for idx in preds[i]]
        filtered_pred  = " ".join([word for word in pred_decoded if word not in ['O', '[CLS]', '[SEP]', '[PAD]']])
        filtered_label = " ".join([word for word in label_decoded if word not in ['O', '[CLS]', '[SEP]', '[PAD]']])
        true_loc.append(filtered_label)
        pred_loc.append(filtered_pred)
    wer_scores  = werpy.wers(true_loc, pred_loc)
    average_wer = sum(wer_scores) / len(wer_scores)
    # print(average_wer)
    # print(true_loc)
    # print(pred_loc)

    precision, recall, f1, _ = precision_recall_fscore_support(labels.flatten(), preds.flatten(), average='weighted')
    acc = accuracy_score(labels.flatten(), preds.flatten())
    return {
        'accuracy'  : acc,
        'precision' : precision,
        'recall'    : recall,
        'f1'        : f1,
        'wer'       : average_wer
    }

### **Data preparation**

In [52]:
train_dfs = []
dev_dfs   = []
path_dfs  = "../data/self_scrapped/raw"
for filename in os.listdir(path_dfs):
    if filename.endswith(".csv"):
        file_path = os.path.join(path_dfs, filename)
        if filename.startswith("train"):
            df = pd.read_csv(file_path)
            train_dfs.append(df)
        elif filename.startswith("dev"):
            df = pd.read_csv(file_path)
            dev_dfs.append(df)

df_train = pd.concat(train_dfs, ignore_index=True) if train_dfs else pd.DataFrame()
df_dev   = pd.concat(dev_dfs, ignore_index=True) if dev_dfs else pd.DataFrame()

df_train = pd.concat([df_train, df_dev])
print("TRAIN SHAPE: ", df_train.shape)

TRAIN SHAPE:  (16448, 3)


- Use augmented df

In [53]:
df = pd.read_csv('../data/provided/TrainEncoded.csv')

def parse_location_mentions(location_mentions):
    location_dict = {}
    if pd.notna(location_mentions):
        parts = location_mentions.split(' * ')
        for part in parts:
            location, loc_type = part.split('=>')
            location_dict[location.strip()] = loc_type.strip()
    return location_dict

location_type_dict = {}
for location in df_train['location_mentions'].dropna():
    location_type_dict.update(parse_location_mentions(location))

def label_location(row, location_type_dict):
    location = row['location']
    labeled_locations = []
    words = location.split()
    
    while words:
        for i in range(len(words), 0, -1):
            sub_location = ' '.join(words[:i]).strip()
            if sub_location in location_type_dict:
                labeled_locations.append(f"{sub_location}=>{location_type_dict[sub_location]}")
                words = words[i:]
                break
        else:
            words = words[:-1]
    if labeled_locations:
        return ' * '.join(labeled_locations)
    else:
        return None

df['location_mentions'] = df.apply(lambda row: label_location(row, location_type_dict), axis=1)

- Preprocessing pipeline

In [54]:
df = Preprocess.remove_non_ascii(df, column_name='text')
df = Preprocess.remove_usertag(df, column_name='text')
df = Preprocess.reformat_hashtag(df, column_name='text')
df = Preprocess.remove_prefix(df, df_type="train", text_column='text')
df = Preprocess.reformat_useless_char(df, column_name='text')
df.head(10)

Unnamed: 0,tweet_id,text,location,location_mentions
0,ID_1001136212718088192,EllicottCity is known for its vibrant art scene.,EllicottCity,EllicottCity=>CITY
1,ID_1001136696589631488,Flash floods struck a Maryland city on Sunday ...,Maryland,Maryland=>STATE
2,ID_1001136950345109504,State of emergency declared for Maryland flood...,Maryland,Maryland=>STATE
3,ID_1001137334056833024,Other parts of Maryland also saw significant d...,Baltimore Maryland,Baltimore=>CITY * Maryland=>STATE
4,ID_1001138374923579392,Catastrophic Flooding Slams Ellicott City Mary...,Ellicott City Maryland,Ellicott City=>CITY * Maryland=>STATE
5,ID_1001138377717157888,1 missing after flash FLOODING devastates Elli...,Ellicott City Maryland,Ellicott City=>CITY * Maryland=>STATE
6,ID_1001139323075416064,The scenic spots in Ellicott City Maryland are...,Ellicott City Maryland,Ellicott City=>CITY * Maryland=>STATE
7,ID_1001140017207459840,Maryland has a variety of historical landmarks.,Maryland,Maryland=>STATE
8,ID_1001140276377935872,The local food markets in Maryland are a feast...,Maryland,Maryland=>STATE
9,ID_1001140804503601152,Baltimore is a popular destination for both re...,Baltimore,Baltimore=>CITY


In [55]:
lemma_path = "../data/new/train.encoded.lemma.csv"
if not os.path.exists(lemma_path):
    df_ = Preprocess.remove_stop_words(df, column_name='text', new_col="text_transformed", transformation=[
        "tokenize", "lemma", "lower"
    ], save_in=lemma_path)
else:
    df_ = pd.read_csv(lemma_path)

# Substitution
df = df_.drop(columns=['text'])
df = df.rename(columns={'text_transformed': 'text'})
df = df.dropna(subset=['text'])

In [56]:
print(df.shape)
print(df.isnull().sum())

(43335, 4)
tweet_id                0
location                0
location_mentions    1831
text                    0
dtype: int64


- Train dev split

In [57]:
df_idx, ner_classes = MultiLabelNERStratify.process_location_mentions(df)
train_idx, test_idx, train_label_freq, test_label_freq = MultiLabelNERStratify.stratify_train_test_split_multi_label(
    df_idx.tweet_id, 
    np.vstack(df_idx.location_array_freq.values), 
    test_size=0.2
)

# Filter the original DataFrame based on the tweet_id column
train_idx_list = train_idx.tolist() if hasattr(train_idx, 'tolist') else list(train_idx)
test_idx_list = test_idx.tolist() if hasattr(test_idx, 'tolist') else list(test_idx)
df_train = df[df['tweet_id'].isin(train_idx_list)]
df_dev   = df[df['tweet_id'].isin(test_idx_list)]

# Print the split sizes
print("TRAIN SHAPE: ", df_train.shape)
print("DEV SHAPE: ", df_dev.shape)

TRAIN SHAPE:  (33984, 4)
DEV SHAPE:  (9351, 4)


In [58]:
df_tag_train = Preprocess.build_bilou_encoding(df_train, text_col="text", save_in="../data/new/train.encoded.bilou.tag.csv")
df_tag_dev   = Preprocess.build_bilou_encoding(df_dev, text_col="text", save_in="../data/new/dev.encoded.bilou.tag.csv")

df_tag_train.head(5)

Unnamed: 0,sentence_id,words,labels
0,ID_1001136212718088192,ellicottcity,U-CITY
1,ID_1001136212718088192,be,O
2,ID_1001136212718088192,know,O
3,ID_1001136212718088192,for,O
4,ID_1001136212718088192,its,O


In [61]:
# Train
train_dataset = df_tag_train.groupby('sentence_id').agg({
    'words': lambda x: ' '.join(x),
    'labels': lambda x: ' '.join(x)
}).reset_index()
train_dataset = train_dataset.drop(columns=['sentence_id'])
train_dataset.to_csv('../data/new/kaggle/train_dataset.csv')

# Dev
test_dataset = df_tag_dev.groupby('sentence_id').agg({
    'words': lambda x: ' '.join(x),
    'labels': lambda x: ' '.join(x)
}).reset_index()
test_dataset = test_dataset.drop(columns=['sentence_id'])
test_dataset.to_csv('../data/new/kaggle/dev_dataset.csv')

# Full
data = pd.concat([train_dataset, test_dataset])

In [None]:
# temp for debug -- should be deleted for prod training
# test_dataset = test_dataset.sample(frac=0.001, random_state=200).reset_index(drop=True)
# train_dataset = train_dataset.sample(frac=0.001, random_state=200).reset_index(drop=True)

In [None]:
test_dataset

Unnamed: 0,words,labels
0,flash flood strike a maryland city on sunday w...,O O O O U-STATE O O O O O O O O O O O O
1,state of emergency declare for maryland floodi...,O O O O O U-STATE O O
2,maryland have a variety of historical landmark,U-STATE O O O O O O
3,the local food market in maryland be a feast f...,O O O O O U-STATE O O O O O O
4,maryland be know for its beautiful park and ga...,U-STATE O O O O O O O O
...,...,...
9345,mexico be know for its picturesque riverfront,U-COUNTRY O O O O O O
9346,cuban doctor treat mexicos earthquake victim,O O O O O O
9347,mexico be know for its picturesque landscape,U-COUNTRY O O O O O O
9348,mexico earthquake relief font bundle,U-COUNTRY O O O O


In [None]:
data.head()

Unnamed: 0,words,labels
0,ellicottcity be know for its vibrant art scene,U-CITY O O O O O O O
1,other part of maryland also see significant da...,O O O U-STATE O O O O O O O O O U-CITY O O O O...
2,catastrophic flooding slam ellicott city maryl...,O O O B-CITY L-CITY U-STATE O O O O O O O
3,1 miss after flash flooding devastate ellicott...,O O O O O O B-CITY L-CITY U-STATE O
4,the scenic spot in ellicott city maryland be p...,O O O O B-CITY L-CITY U-STATE O O O O


### **Modeling preparation**

- **Prepare custom label mappings**: 

<div style="padding-left: 2.5em;">Before fine-tuning, the BILOU location labels (e.g., `B-CITY`, `B-COUNTY`, ...) must be mapped to integer IDs that the model can train on. The mapping is needed in both directions: BERT outputs one logit per label ID for each token, and the highest-scoring ID is then converted back to its label string.</div>
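
As a rough illustration of the round trip from labels to IDs and back, the sketch below builds both mappings from a toy label set and decodes a few hand-written logit rows. The names `toy_rows`, `toy_label2id`, `toy_id2label`, and `toy_logits` are made up for this example, and the logit values are arbitrary; sorting the tags is a choice made here so that the IDs come out the same on every run.

```python
# Toy label strings in the same space-separated format as the `labels` column
toy_rows = ["U-CITY O O", "O B-CITY L-CITY U-STATE"]

# Sorted so that the label -> ID assignment is deterministic
toy_tags = sorted({tag for row in toy_rows for tag in row.split()})
toy_label2id = {tag: idx for idx, tag in enumerate(toy_tags)}
toy_id2label = {idx: tag for tag, idx in toy_label2id.items()}

# One row of logits per token, one column per label ID (values are invented)
toy_logits = [
    [0.1, 0.0, 0.2, 2.5, 0.3],
    [0.0, 0.0, 3.0, 0.0, 0.0],
    [0.2, 0.1, 0.0, 0.4, 1.9],
]

# Argmax over the label dimension, then map IDs back to label strings
pred_ids = [max(range(len(row)), key=row.__getitem__) for row in toy_logits]
decoded = [toy_id2label[idx] for idx in pred_ids]

print(toy_label2id)
print(decoded)
```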

In [None]:
# Extract unique tags from word labels; sort them so the label IDs are stable across runs
tags = sorted(set(" ".join(data.labels).split(' ')))

# Create label to ID and ID to label mappings
label2id = {k: v for v, k in enumerate(tags)}
id2label = {v: k for v, k in enumerate(tags)}

# Inspect the tags
#tags

### **Setup the model and tokenizer**

- **Pretrained model from Hugging Face**: 

<div style="padding-left: 2.5em;">We load the tokenizer and the model from Hugging Face's library of pre-trained models. `bert-large-uncased` is a general-purpose pre-trained checkpoint, not one already fine-tuned for Named Entity Recognition (NER): the token-classification head is newly initialized, as the warning in the output below indicates, and is trained during fine-tuning. The tokenizer converts the input text into the sub-word token IDs that the model expects, and the model predicts one label per token.</div>

In [None]:
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
model = BertForTokenClassification.from_pretrained(
    "bert-large-uncased",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id
)
model.to(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-large-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024

- **DataCollator**: 

<div style="padding-left: 2.5em;">
The custom dataset class defined in the helpers section tokenizes each sentence and pads or truncates it to a fixed length. The `DataCollatorForTokenClassification` from the Hugging Face `transformers` library pads each batch dynamically to its longest sequence. Because the custom dataset already pads every example to `MAX_LEN`, the collator has no padding left to add here and mainly assembles the examples into batch tensors.
</div>
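
For readers unfamiliar with dynamic padding, the following pure-Python sketch shows the idea on variable-length examples. It is not the library's implementation: `pad_batch` and the toy token IDs are assumptions for illustration. The label padding value of `-100` mirrors the default `label_pad_token_id` of `DataCollatorForTokenClassification`, which is also the index that PyTorch's cross-entropy loss ignores by default.

```python
def pad_batch(features, pad_token_id=0, label_pad_id=-100):
    """Pad every example in a batch to the length of the longest one."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_tokens = len(f["input_ids"])
        n_pad = max_len - n_tokens
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # 1 for real tokens, 0 for padding
        batch["attention_mask"].append([1] * n_tokens + [0] * n_pad)
        # Padded label positions are excluded from the loss
        batch["labels"].append(f["labels"] + [label_pad_id] * n_pad)
    return batch


# Toy examples with invented token IDs and label IDs
features = [
    {"input_ids": [11, 12, 13, 14], "labels": [2, 0, 1, 2]},
    {"input_ids": [21, 22], "labels": [2, 4]},
]
batch = pad_batch(features)
print(batch["input_ids"])
print(batch["attention_mask"])
print(batch["labels"])
```

With dynamic padding, a batch of short sentences stays short, whereas padding every example to a fixed length (as `CustomDataset` does) spends compute on padding tokens.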

In [None]:
data_collator = DataCollatorForTokenClassification(tokenizer)

- Create custom datasets for training and testing

In [None]:
data['labels'].apply(lambda x: len(x.split(" "))).max()

58

In [None]:
MAX_LEN = 100

training_set = CustomDataset(train_dataset, tokenizer, MAX_LEN)
testing_set = CustomDataset(test_dataset, tokenizer, MAX_LEN)

- Define training parameters

In [None]:
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 8
EPOCHS = 2

- Setup trainer

In [None]:
torch.cuda.empty_cache()

In [None]:
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=VALID_BATCH_SIZE,
    warmup_steps=25,
    weight_decay=0.001,
    logging_dir='./logs',
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=25,
    save_steps=50,
    save_total_limit=2,
    gradient_accumulation_steps=4,  # Accumulate gradients for a larger effective batch size
    fp16=False,  # Mixed precision is off; set to True on a supported GPU for faster training
    report_to=["none"]  # Set to ["wandb"] if you have a WANDB API key
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=training_set,
    eval_dataset=testing_set,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
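
The interaction between batch size, gradient accumulation, and epochs determines how many optimizer steps a run takes. The sketch below works through that arithmetic, assuming a single device and one training example per row of the 33,984-row training split; `total_optimizer_steps` is a helper written for this example, and the exact rounding of a partial final accumulation window may differ between `transformers` versions (here the division is exact, so it does not matter).

```python
import math


def total_optimizer_steps(n_examples, batch_size, grad_accum, epochs):
    """Approximate number of optimizer updates for a training run."""
    # Each optimizer update consumes batch_size * grad_accum examples
    effective_batch = batch_size * grad_accum
    steps_per_epoch = math.ceil(n_examples / effective_batch)
    return steps_per_epoch * epochs


# Values taken from this notebook: TRAIN_BATCH_SIZE=16,
# gradient_accumulation_steps=4, EPOCHS=2
print(total_optimizer_steps(33984, 16, 4, 2))  # 1062
```

The result agrees with the total of 1062 steps shown by the training progress bar. It also puts the schedule in perspective: `eval_steps=25` triggers a full evaluation on the dev set roughly 42 times over the run.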

### **Fine-tuning**

In [47]:
trainer.train()

  0%|          | 0/1062 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
***

# Measuring average WER

Now that the model is trained, let's run inference on the training data and evaluate the predictions against the custom metric.
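
As a reminder of what the metric measures, here is a minimal sketch of word error rate: the word-level edit distance (substitutions, deletions, insertions) divided by the number of reference words. `word_error_rate` is written for illustration only; the notebook uses `jiwer.wer`, whose handling of edge cases may differ.

```python
def word_error_rate(reference, hypothesis):
    """Word-level Levenshtein distance divided by the reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] holds the edit distance between the processed prefix of ref and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # match or substitution
            ))
        prev = curr
    return prev[-1] / len(ref)


print(word_error_rate("ellicott city maryland", "ellicott city maryland"))  # 0.0
print(word_error_rate("maryland", "ellicott city maryland"))                # 2.0
```

Because insertions count as errors while the denominator is the reference length, a prediction with many extra words can score above 1.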

In [None]:
# train = pd.read_csv("/kaggle/input/dattttt/dat/Train.csv")
# 
# train = train[~train['text'].isna()]

In [None]:
# implement in batches later
# train_predictions = infer_on_sentences(train.text.to_list(), trainer.model, tokenizer)
# train['prediction_raw'] = train_predictions

100%|██████████| 16448/16448 [15:23<00:00, 17.80it/s]


In [None]:
# train.to_csv('news.csv', index=False)

In [None]:
# df, average_wer = calculate_performance_metric(train, col2='prediction_raw')
# average_wer

1.28802675855813

Can we do better?

Here we will perform some post-inference cleaning, using the `clean_prediction` function defined in the helpers section. Nothing fancy, just a bunch of heuristics.
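
One of those heuristics is rejoining WordPiece continuation tokens (those prefixed with `##`) into whole words. The sketch below shows the idea on a token list; `merge_wordpieces` is a hypothetical helper, and it differs from `clean_prediction`, which works on the joined string and drops a leading orphan piece instead of keeping it. The example tokens are invented and not actual tokenizer output.

```python
def merge_wordpieces(tokens):
    """Glue '##' continuation pieces onto the preceding token."""
    words = []
    for token in tokens:
        if token.startswith("##"):
            piece = token[2:]
            if words:
                words[-1] += piece
            else:
                # Orphan continuation with nothing before it: keep the bare piece
                words.append(piece)
        else:
            words.append(token)
    return words


pieces = ["elli", "##cott", "##city", "maryland"]
print(" ".join(merge_wordpieces(pieces)))  # ellicottcity maryland
```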

In [None]:
# train['prediction_clean'] = train.apply(clean_prediction, axis=1)

In [None]:
# df, average_wer = calculate_performance_metric(train, col2='prediction_clean')
# average_wer

0.5019827050651466

# Submission

In [None]:
# test = pd.read_csv("/kaggle/input/dattttt/dat/Test.csv")

In [None]:
# implement in batches later
# test_predictions = infer_on_sentences(test.text.to_list(), model, tokenizer)

100%|██████████| 2942/2942 [02:45<00:00, 17.81it/s]


In [None]:
# test['prediction_raw'] = test_predictions
# test['prediction'] = test.apply(clean_prediction, axis=1)
# test['prediction'] = test['prediction'].replace("", " ")

In [None]:
# test[['tweet_id', 'prediction']].to_csv("bert-large-uncased-fine-tuned-1-epoch+huristic-cleaning.csv", index=False)