<div style="padding: 0.5em; background-color: #1876d1; color: #fff; font-weight: bold; font-size: 1.4em;">
    Fine-Tuning BERT for Location Mention Recognition
</div>

This notebook demonstrates the process of fine-tuning a BERT model to recognize and categorize location mentions in text using the IDRISI dataset. The task at hand is a type of Named Entity Recognition (NER) where the goal is to identify and classify location names, such as countries, cities, or landmarks, within a given text.

We utilize the BILOU (Begin, Inside, Last, Outside, Unit) labeling scheme, which provides detailed annotations of entity boundaries. Fine-tuning BERT with these structured labels allows the model to leverage its deep contextual understanding to perform highly accurate token classification, essential for detecting location mentions in diverse textual data.

The notebook is structured as follows:
1. **Setup and Installation**: Install and import the necessary libraries.
2. **Data Ingestion and Preprocessing**: Load the IDRISI dataset and prepare it for modeling, including tokenization and label mapping.
3. **Modeling Preparation**: Create custom datasets, define label mappings, and set up the BERT model for token classification.
4. **Fine-Tuning**: Train the BERT model on the labeled data, optimizing for accuracy in location mention recognition.
5. **Evaluation**: Assess the performance of the fine-tuned model using the Word Error Rate Metric.

### **Setup & Utils**

In [3]:
!pip install transformers jiwer werpy wandb pandas accelerate -U

Collecting transformers
  Downloading transformers-4.44.2-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jiwer
  Downloading jiwer-3.0.4-py3-none-any.whl.metadata (2.6 kB)
Collecting werpy
  Downloading werpy-2.1.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (14 kB)
Collecting wandb
  Downloading wandb-0.17.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.6 kB)
Collecting accelerate
  Downloading accelerate-0.34.2-py3-none-any.whl.metadata (19 kB)
Collecting rapidfuzz<4,>=3 (from jiwer)
  Downloading rapidfuzz-3.9.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading transformers-4.44.2-py3-none-any.whl (9.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.5/9.5 MB[0m [31m61.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m


In [4]:
# basic
import sys, os, re, torch, werpy, jiwer, wandb
from collections import Counter
import pandas as pd
from tqdm import tqdm
import numpy as np
from typing import Literal

# ml commons
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForTokenClassification, Trainer, TrainingArguments, DataCollatorForTokenClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# devie septup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# utils setup
current_directory = os.getcwd()
root_directory = os.path.abspath(os.path.join(current_directory, os.pardir))
BASE_PATH = "/kaggle/input/lmr-generated"
sys.path.append(BASE_PATH+"utils")
os.environ["WANDB_PROJECT"] = "fine-tune-bert-for-lmr-2ep-t16v8-fullDF.ipynb"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
os.environ['CUDA_LAUNCH_BLOCKING']="1"

In [5]:
# custom utils
#from utils.io import Predictions
#from utils.metrics import LMR_Metrics
#from utils.io import LMR_BILOU_Scrapper, LMR_JSON_Scrapper
#from utils.preprocessing import Preprocess
#from utils.stratify import MultiLabelNERStratify

### **Helpers**

In [70]:
# Helper 1
def tokenize_and_preserve_labels(sentence, text_labels, tokenizer):
    tokenized_sentence, labels = [], []
    for word, label in zip(sentence.split(), text_labels.split(" ")):
        tokenized_word = tokenizer.tokenize(word)
        n_subwords = len(tokenized_word)
        tokenized_sentence.extend(tokenized_word)
        labels.extend([label] * n_subwords)
    return tokenized_sentence, labels

# Helper 2
class CustomDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __getitem__(self, index):
        sentence = self.data.words[index]  
        word_labels = self.data.labels[index]  
        tokenized_sentence, labels = tokenize_and_preserve_labels(sentence, word_labels, self.tokenizer)
        
        tokenized_sentence = ["[CLS]"] + tokenized_sentence + ["[SEP]"]
        labels.insert(0, "O")
        labels.insert(-1, "O")

        if len(tokenized_sentence) > self.max_len:
            tokenized_sentence = tokenized_sentence[:self.max_len]
            labels = labels[:self.max_len]
        else:
            tokenized_sentence += ['[PAD]'] * (self.max_len - len(tokenized_sentence))
            labels += ["O"] * (self.max_len - len(labels))

        attn_mask = [1 if tok != '[PAD]' else 0 for tok in tokenized_sentence]
        ids = self.tokenizer.convert_tokens_to_ids(tokenized_sentence)
        label_ids = [label2id[label] for label in labels]
        
        return {
            'input_ids': torch.tensor(ids, dtype=torch.long),
            'attention_mask': torch.tensor(attn_mask, dtype=torch.long),
            'labels': torch.tensor(label_ids, dtype=torch.long)
        }
    
    def __len__(self):
        return len(self.data)
    
# Helper 3
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    # For WER Metrics
    true_loc, pred_loc = [], []
    for i in range(len(labels)):
        label_decoded = [id2label[pred] for pred in labels[i]]
        pred_decoded  = [id2label[pred] for pred in preds[i]]
        filtered_pred  = " ".join([word for word in pred_decoded if word not in ['O', '[CLS]', '[SEP]', '[PAD]']])
        filtered_label = " ".join([word for word in label_decoded if word not in ['O', '[CLS]', '[SEP]', '[PAD]']])
        true_loc.append(filtered_label if filtered_label != "" else "Empty")
        pred_loc.append(filtered_pred if filtered_pred != "" else "Empty")
    wer_scores  = werpy.wers(true_loc, pred_loc)
    average_wer = sum(wer_scores) / len(wer_scores)
    # print(average_wer)
    # print(true_loc)
    # print(pred_loc)

    precision, recall, f1, _ = precision_recall_fscore_support(labels.flatten(), preds.flatten(), average='weighted')
    acc = accuracy_score(labels.flatten(), preds.flatten())
    return {
        'accuracy'  : acc,
        'precision' : precision,
        'recall'    : recall,
        'f1'        : f1,
        'wer'       : average_wer
    }

# Helper 4
def make_infererence(sentences, model, tokenizer, max_len=100, with_extra=False):
    model.eval()
    results = []
    extra_results = []
    
    for sentence in tqdm(sentences):
        tokenized_sentence = tokenizer(
            sentence.split(),
            is_split_into_words=True,
            return_offsets_mapping=False,
            padding='max_length',
            truncation=True,
            max_length=max_len,
            return_tensors="pt"
        )
        
        input_ids = tokenized_sentence['input_ids'].to(device)
        attention_mask = tokenized_sentence['attention_mask'].to(device)        
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=2)  # Get the index of the highest logit for each token
        
        pred_labels = [id2label[pred.item()] for pred in predictions[0]]
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
        filtered_tokens = [
            token for token, label in zip(tokens, pred_labels)
            if label != 'O' and token not in ['[CLS]', '[SEP]', '[PAD]']
        ]
        filtered_labels = [
            label for token, label in zip(tokens, pred_labels)
            if label != 'O' and token not in ['[CLS]', '[SEP]', '[PAD]']
        ]
        
        results.append(" ".join(filtered_tokens))
        extra_results.append(filtered_labels)
        
    if with_extra:
        return results, extra_results
    return results

# Helper 5
def compute_wer_eval(df, col1='location', col2='prediction'):
    def calculate_wer(row):
        return jiwer.wer(str(row[col1]), str(row[col2]))
    df['WER'] = df.apply(calculate_wer, axis=1)
    average_wer = df['WER'].mean()
    return df, average_wer

In [71]:
def find_indices(text, locations):
    try:
        if locations.strip() == '':
            return ' '
        words = locations.split()
        
        combinations = []
        for length in range(1, 4):
            for i in range(len(words) - length + 1):
                combination = ' '.join(words[i:i + length])
                combinations.append(combination)
        indices = []
        for comb in combinations:
            for match in re.finditer(re.escape(comb), text, re.IGNORECASE):
                indices.append((match.start(), match.end(), comb))

        # keep only indices with the longest match. 
        indices = sorted(indices, key=lambda x: len(x[2]), reverse=True)
        # drop duplicated start indices 
        indices = [indices[0]] + [x for i, x in enumerate(indices[1:], 1) if x[0] not in [y[0] for y in indices[:i]]]
        # drop indices that are contained in other indices
        indices = [x for i, x in enumerate(indices) if not any(x[0] >= y[0] and x[1] <= y[1] for y in indices[:i] + indices[i+1:])]

        # return group of words corresponding to the indices in text 
        words_from_text = []
        for start, end, comb in indices:
            words_from_text.append(text[start:end])
    except IndexError: 
        words_from_text = find_substring_indices(text, locations)

    return " ".join(sorted(set(words_from_text)))

def find_substring_indices(text, locations):
    # Generate all possible substrings of the word
    substrings = [locations[i:j] for i in range(len(locations)) for j in range(i + 1, len(locations) + 1)]
    indices = []
    for substring in substrings:
        for match in re.finditer(re.escape(substring), text, re.IGNORECASE):
            indices.append((match.start(), match.end(), substring))

    # keep only indices with the longest match.
    indices = sorted(indices, key=lambda x: len(x[2]), reverse=True)

    # drop duplicated start indices
    indices = [indices[0]] + [x for i, x in enumerate(indices[1:], 1) if x[0] not in [y[0] for y in indices[:i]]]

    # drop indices that are contained in other indices
    indices = [x for i, x in enumerate(indices) if not any(x[0] >= y[0] and x[1] <= y[1] for y in indices[:i] + indices[i+1:])]

    # keep indices of words long of at least 2 characters
    indices = [x for x in indices if len(x[2]) >= 3]

    # words_from_text 
    words_from_text = []
    for start, end, substring in indices:
        words_from_text.append(substring)

    # sort words_from_text
    words_from_text = sorted(words_from_text)

    return words_from_text

def heuristic_postprocess_1(row):
    _id    = row['tweet_id']
    text  = row['raw_prediction']
    tweet = row['text']
    
    
    # 1 - Clean Special char
    replacements = {
        " ##": "",
        "##": "",
        ",": "",
        "U . S .": "U.S.",
        "U . S": "U.S.",
        "U S": "U.S.",
        "L . A .": "L.A.",
        "L . A": "L.A.",
        "L A": "L.A.",
        "P . R .": "P.R.",
        "P . R": "P.R.",
        "P R": "P.R.",
        "N . C .": "N.C.",
        "N . N": "N.C.",
        "N C": "N.C.",
        "D . C .": "D.C.",
        "D . C": "D.C.",
        "D C": "D.C."
    }
    for word, replacement in replacements.items():
        text = text.replace(word, replacement)
     
    #"""
    # 2 - Special Replace
    text = re.sub(r'\bM\b', 'Md.', text)
    text = re.sub(r'\bElliot\b', '', text)
    text = re.sub(r'\bMat\b', 'Matti', text)
    text = re.sub(r'\bSD\b', 'SDMA', text)
    text = re.sub(r'\bZ\b', 'Zimba', text)
    text = re.sub(r'\btt\b', 'Hutt', text)
    text = re.sub(r'\bbe\b', 'Brooklyn', text)
    text = re.sub(r'\bly\b', 'welly', text)
    text = re.sub(r'\bgree\b', 'greece', text)
    text = re.sub(r'\bAt\b', 'Attica', text)
    
    
    # 3 - Join City or County or New as one word
    pattern1 = r'\b(\w+)\s(city|CITY|City|county|COUNTY|County)\b'
    pattern2 = r'\b(New|NEW|new|United|United__Arabe|East)\s(\w+)\b'
    def replace_func1(match):
        first_word = match.group(1)
        city_word = match.group(2)
        return f'{first_word}__{city_word}'
    def replace_func2(match):
        first_word = match.group(1)
        next_word = match.group(2)
        return f'{first_word}__{next_word}'
    text = re.sub(pattern1, replace_func1, text)
    text = re.sub(pattern2, replace_func2, text)
    
    # 4 - Remove repeated groups of words
    words = text.split()
    seen_words = set()
    unique_words = []
    for word in words:
        if word not in seen_words:
            seen_words.add(word)
            unique_words.append(word)
    
    # 5 - Sort location in Alphabetic order
    unique_words = [place.replace("__", " ") for place in unique_words]
    unique_words = sorted(unique_words)
    text = " ".join(unique_words)
    
    # 6 - Remove words with length lower than 2
    text = " ".join([word for word in text.split() if len(word) >= 2])
    #"""
    
    # Desiré processing
    #text = find_indices(tweet, text)
    
    # 7 - Return location
    if not isinstance(text, str) or not text.strip():
        return " "
    return text.strip()

In [72]:
class DynamicTextAligner:
    """
    **LMR-Text Local Alignment Search Class**
    Idea: The idea is to inspire from BLAST, Basic Local Alignment Search Tool for genomics data 
    and develop light and simple alignment search tool for LMR text. We have to take raw 
    prediction from the model and find a match within the initl tweet to identify the correct 
    word the model is trying to predict.
    """
    def __init__(self, text, subtext):
        self.text = text
        self.subtext = subtext
        self.subtext_chunks = subtext.split()
        self.chunk_ids = list(range(len(self.subtext_chunks)))
        self.text_words_offsets = self._get_text_word_offsets()
    
    def _get_text_word_offsets(self):
        words = self.text.split()
        word_offsets = []
        current_position = 0
        
        for word in words:
            start_offset = self.text.find(word, current_position)
            end_offset = start_offset + len(word) - 1
            word_offsets.append({
                'word': word,
                'start_offset': start_offset,
                'end_offset': end_offset
            })
            current_position = end_offset + 1

        return word_offsets
    
    def find_chunk_positions(self):
        results = []
        text_len = len(self.text)

        current_pos = 0
        for idx, chunk in enumerate(self.subtext_chunks):
            chunk_len = len(chunk)
            
            # Search for the chunk starting from the current position
            match = None
            for i in range(current_pos, text_len - chunk_len + 1):
                if self.text[i:i + chunk_len] == chunk:
                    match = (i, i + chunk_len - 1)
                    break
            
            if match:
                start_offset, end_offset = match
                results.append({
                    'chunk_id': idx,
                    'chunk': chunk,
                    'start_offset': start_offset,
                    'end_offset': end_offset
                })
                current_pos = end_offset + 1

        return results
    
    def merge_consecutive_words(self, words):
        merged_words = []
        i = 0
        while i < len(words):
            current_word = words[i]
            if i + 1 < len(words):
                next_word = words[i + 1]
                if current_word['end_offset'] + 2 == next_word['start_offset']:
                    merged_word = {
                        'word': f"{current_word['word']} {next_word['word']}",
                        'start_offset': current_word['start_offset'],
                        'end_offset': next_word['end_offset']
                    }
                    merged_words.append(merged_word)
                    i += 2
                    continue
            merged_words.append(current_word)
            i += 1
        return merged_words

    def heuristic_processing(self, merged_words):
        # 1 - Replace special cases
        punctuations = [",", ";", ":", "#", "(", ")", "\"", "[", "]", "?"]
        output = [word['word'].split("’")[0] for word in merged_words]
        output = [word.replace(".,", ".") for word in output]
        output = [subword for word in output for subword in word.split('/')]
        words  = [word.translate(str.maketrans('', '', ''.join(punctuations))) for word in output]

        # 2 - Handle hyphens and capital letters
        processed_words = []
        for word in words:
            if word.isupper():
                processed_words.append(word)
            else:
                processed_word = word.replace("-", " ")
                processed_words.append(processed_word)
                
        # 3 - Process dots in words
        final_words = []
        for word in processed_words:
            if "." in word and word.count(".") < 2:
                if word.endswith("Md."):
                    final_words.append(word)
                else:
                    final_words.append(word.replace(".", ""))
            else:
                final_words.append(word)
                
        final_words = sorted(final_words)
        output = " ".join(final_words)
        return output

    def get_alignment(self, mode: Literal["dict", "flat", "groups", "flat_groups", "flat_sorted_groups", "flat_sorted_groups+heur"] = "dict"):
        matches = self.find_chunk_positions()
        aligned_words = []
        remaining_word_offsets = self.text_words_offsets.copy()

        for match in matches:
            chunk_start = match['start_offset']
            chunk_end = match['end_offset']

            for i, word_info in enumerate(remaining_word_offsets):
                word_start = word_info['start_offset']
                word_end = word_info['end_offset']

                if word_start <= chunk_start and word_end >= chunk_end:
                    aligned_words.append(word_info)

                    del remaining_word_offsets[i]
                    break

        if mode == "flat":
            output = [word['word'] for word in aligned_words]
        elif mode == "groups":
            output = self.merge_consecutive_words(aligned_words)
        elif mode == "flat_groups":
            merged_words = self.merge_consecutive_words(aligned_words)
            output = " ".join([word['word'] for word in merged_words])
        elif mode == "flat_sorted_groups":
            merged_words = self.merge_consecutive_words(aligned_words)
            output = " ".join(sorted([word['word'] for word in merged_words]))
        elif mode == "flat_sorted_groups+heur":
            merged_words = self.merge_consecutive_words(aligned_words)
            output = self.heuristic_processing(merged_words)
        else:
            output = aligned_words
        return output

    
    def display_results(self):
        matches = self.find_chunk_positions()
        for match in matches:
            print(f"Chunk ID: {match['chunk_id']}, Chunk: '{match['chunk']}', Start: {match['start_offset']}, End: {match['end_offset']}")
            

def TLAST_postprocess(row):
    generated_text  = row['raw_prediction']
    targeted_text   = row['text']
    
    if not isinstance(generated_text, str) or not generated_text.strip():
        return " "
    
    # Clean Special Clean
    replacements = {
        " ##": "",
        "##": "",
    }
    for word, replacement in replacements.items():
        generated_text = generated_text.replace(word, replacement)
    
    # Call TLAST
    cleaned_text = DynamicTextAligner(targeted_text, generated_text).get_alignment(
        mode="flat_sorted_groups+heur"
    )
    return cleaned_text.strip() 

### **Data preparation**

In [79]:
# Train
train_dataset = pd.read_csv('/kaggle/input/lmr-full/train_bilou.csv')

# Dev
test_dataset = pd.read_csv('/kaggle/input/lmr-full/dev_bilou.csv')

# Full
data = pd.concat([train_dataset, test_dataset])

In [80]:
data.head()

Unnamed: 0,words,labels
0,Flash floods struck a Maryland city on Sunday ...,O O O O U-STATE O O O O O O O O O O O O
1,State of emergency declared for Maryland flood...,O O O O O U-STATE O O O
2,Other parts of Maryland also saw significant d...,O O O U-STATE O O O O O O O O O U-CITY O O O O...
3,Catastrophic Flooding Slams Ellicott City Mary...,O O O U-CITY U-CITY U-STATE O O O O O O O O
4,WATCH 1 missing after flash FLOODING devastate...,O O O O O O O U-CITY U-CITY U-STATE O


### **Modeling preparation**

- **Prepare custom label mappings**: 

<div style="padding-left: 2.5em;">Before fine-tuning, it’s essential to map the location mention labels from the BILOU format to a format that BERT can understand. This involves converting categorical labels (e.g., `B-CITY`,`B-COUNTY`, ...) into integer IDs, which the model will use during training. This mapping is critical because BERT outputs logits for each token, which are then converted back to these labels.</div>

In [81]:
# Extract unique tags from word labels
tags = set(" ".join(data.labels).split(' '))

# Create label to ID and ID to label mappings
label2id = {k: v for v, k in enumerate(tags)}
id2label = {v: k for v, k in enumerate(tags)}

# Get a look of tags
tags

{'B-CITY',
 'B-COUNTRY',
 'B-COUNTY',
 'B-DISTRICT',
 'B-HUMAN-MADE',
 'B-ISLAND',
 'B-NATURAL',
 'B-NEIGHBORHOOD',
 'B-OTHER',
 'B-ROAD',
 'B-STATE',
 'I-CITY',
 'I-HUMAN-MADE',
 'I-ISLAND',
 'I-NATURAL',
 'I-OTHER',
 'I-ROAD',
 'I-STATE',
 'L-CITY',
 'L-CONTINENT',
 'L-COUNTRY',
 'L-COUNTY',
 'L-DISTRICT',
 'L-HUMAN-MADE',
 'L-ISLAND',
 'L-NATURAL',
 'L-OTHER',
 'L-ROAD',
 'L-STATE',
 'O',
 'U-CITY',
 'U-CONTINENT',
 'U-COUNTRY',
 'U-COUNTY',
 'U-DISTRICT',
 'U-HUMAN-MADE',
 'U-ISLAND',
 'U-NATURAL',
 'U-NEIGHBORHOOD',
 'U-OTHER',
 'U-ROAD',
 'U-STATE'}

### **Setup the model and tokenizer**

- **Pretrained model for huggingface**: 

<div style="padding-left: 2.5em;">We retrieve the tokenizer and the model from Huggingface's library of pre-trained models. This allows us to leverage a model that has already been fine-tuned for a specific task, such as Named Entity Recognition (NER). The tokenizer helps preprocess the input text by converting it into a format that the model can interpret, while the model is used to make predictions based on this input.</div>

In [82]:
base_model = "rsuwaileh/IDRISI-LMR-EN-random-typebased" #"FacebookAI/roberta-base" #"bert-large-uncased"

# pull model and tokenizer
tokenizer = BertTokenizer.from_pretrained(base_model)
model = BertForTokenClassification.from_pretrained(
    base_model, #"bert-large-uncased",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True
)
model.to(device)

Some weights of the model checkpoint at rsuwaileh/IDRISI-LMR-EN-random-typebased were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at rsuwaileh/IDRISI-LMR-EN-random-typebased and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([49, 1024]) in the checkpoint and torch.Size([42, 1024]) in the model instantia

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024

- **DataCollator**: 

<div style="padding-left: 2.5em;">
A custom dataset class is created to handle the input data, applying tokenization and ensuring that sequences are properly padded or truncated to fit the model’s expected input size. The `DataCollatorForTokenClassification` from the Hugging Face `transformers` library is used to dynamically pad batches during training, making the process efficient and preventing data leakage between samples.
</div>

In [83]:
data_collator = DataCollatorForTokenClassification(tokenizer)

- Create custom datasets for training and testing

In [84]:
data['labels'].apply(lambda x: len(x.split(" "))).max()

93

In [85]:
MAX_LEN = 100

training_set = CustomDataset(train_dataset, tokenizer, MAX_LEN)
testing_set = CustomDataset(test_dataset, tokenizer, MAX_LEN)

- Define training parameters

In [87]:
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 16
EPOCHS = 2

- Setup trainer

In [88]:
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=VALID_BATCH_SIZE,
    warmup_steps=25,
    weight_decay=0.001,
    logging_dir='./logs',
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=25,
    save_steps=50,
    save_total_limit=2,
    gradient_accumulation_steps=4,
    fp16=True,
    report_to=["none"],
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    metric_for_best_model="wer",
    greater_is_better=False,
    load_best_model_at_end=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=training_set,
    eval_dataset=testing_set,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


### **Fine-tuning**

In [89]:
trainer.train()

Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Wer
25,1.2062,0.105004,0.980073,0.960557,0.980073,0.970217,0.720817
50,0.0397,0.04586,0.990151,0.988225,0.990151,0.989071,0.303611
75,0.0366,0.037643,0.991683,0.989633,0.991683,0.990592,0.275782
100,0.0309,0.035468,0.991848,0.990696,0.991848,0.99112,0.295543
125,0.0266,0.035281,0.992135,0.991118,0.992135,0.99148,0.2868
150,0.0247,0.034114,0.992087,0.991089,0.992087,0.991509,0.290454
175,0.0274,0.034001,0.992218,0.991129,0.992218,0.991602,0.28321
200,0.0237,0.033455,0.992184,0.991099,0.992184,0.991572,0.285779


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


TrainOutput(global_step=224, training_loss=0.22873996131654298, metrics={'train_runtime': 1227.7064, 'train_samples_per_second': 23.445, 'train_steps_per_second': 0.182, 'total_flos': 5200013185641600.0, 'train_loss': 0.22873996131654298, 'epoch': 1.991111111111111})

### **Eval on new train**

In [90]:
train = pd.read_csv("/kaggle/input/lmr-full-train/full_train_1.csv")
train = train[~train['text'].isna()]
train.shape

(16448, 3)

In [91]:
train_predictions = make_infererence(train.text.to_list(), trainer.model, tokenizer)
train['raw_prediction'] = train_predictions
train.head()

100%|██████████| 16448/16448 [17:56<00:00, 15.28it/s]


Unnamed: 0,tweet_id,text,location,raw_prediction
1,ID_1001136696589631488,"Flash floods struck a Maryland city on Sunday,...",Maryland,Maryland
2,ID_1001136950345109504,State of emergency declared for Maryland flood...,Maryland,Maryland
3,ID_1001137334056833024,Other parts of Maryland also saw significant d...,Baltimore Maryland,Maryland Baltimore Maryland
4,ID_1001138374923579392,"Catastrophic Flooding Slams Ellicott City, Mar...",Ellicott City Maryland,El ##lic ##ott City Maryland
5,ID_1001138377717157888,WATCH: 1 missing after flash #FLOODING devasta...,Ellicott City Maryland,El ##lic ##ott City Maryland


In [92]:
_, average_wer = compute_wer_eval(train, col2='raw_prediction')
average_wer

1.0319958995537053

In [93]:
train['prediction'] = train.apply(TLAST_postprocess, axis=1)
_, average_wer = compute_wer_eval(train, col2='prediction')
average_wer

0.4624318651005626

In [94]:
#train['prediction'] = train.apply(TLAST_postprocess, axis=1)
train['prediction'] = train.apply(heuristic_postprocess_1, axis=1)
_, average_wer = compute_wer_eval(train, col2='prediction')
average_wer

0.43409873793891046

In [95]:
_

Unnamed: 0,tweet_id,text,location,raw_prediction,WER,prediction
1,ID_1001136696589631488,"Flash floods struck a Maryland city on Sunday,...",Maryland,Maryland,0.0,Maryland
2,ID_1001136950345109504,State of emergency declared for Maryland flood...,Maryland,Maryland,0.0,Maryland
3,ID_1001137334056833024,Other parts of Maryland also saw significant d...,Baltimore Maryland,Maryland Baltimore Maryland,0.0,Baltimore Maryland
4,ID_1001138374923579392,"Catastrophic Flooding Slams Ellicott City, Mar...",Ellicott City Maryland,El ##lic ##ott City Maryland,0.0,Ellicott City Maryland
5,ID_1001138377717157888,WATCH: 1 missing after flash #FLOODING devasta...,Ellicott City Maryland,El ##lic ##ott City Maryland,0.0,Ellicott City Maryland
...,...,...,...,...,...,...
73066,ID_916080760276299776,Mexico City: at least a thousand buildings dam...,Mexico City,Mexico City,0.0,Mexico City
73068,ID_916125408059445248,Rescue workers recover the body of the last pe...,Mexico City,Mexico City,0.0,Mexico City
73069,ID_916135932285341696,Donate from Facebook to Mexico Earthquake Reli...,Mexico,Mexico,0.0,Mexico
73070,ID_916146805347356672,We are helping our clients in Mexico recover f...,Mexico,Mexico,0.0,Mexico


In [96]:
_.to_csv("train_inference_tlast3.csv")

### **Utils for autocorrect**

In [69]:
#from transformers import pipeline
#from tqdm import tqdm
#tqdm.pandas()
#fix_spelling = pipeline("text2text-generation",model="oliverguhr/spelling-correction-english-base", device=device)

#def correct_spelling(text):
#    return fix_spelling(text, max_length=2048)[0]['generated_text']

#test['fixed_text'] = test['text'].progress_apply(correct_spelling)
#test.head()

### **Make prediction for Context**

In [97]:
test = pd.read_csv("/kaggle/input/lmr-full/test_dataset.csv")
test.head()

Unnamed: 0,tweet_id,text
0,ID_1001154804658286592,What is happening to the infrastructure in New...
1,ID_1001155505459486720,SOLDER MISSING IN FLOOD.. PRAY FOR EDDISON HER...
2,ID_1001155756371136512,RT @TIME: Police searching for missing person ...
3,ID_1001159445194399744,Flash Flood Tears Through Maryland Town For Se...
4,ID_1001164907587538944,Ellicott City #FLOODING Pictures: Maryland Gov...


In [None]:
ids = test["tweet_id"].values
tweets = test["text"].values

# Make prediction
test_predictions = make_infererence(tweets, trainer.model, trainer.tokenizer, max_len=MAX_LEN)

In [None]:
# submission df
test['raw_prediction'] = test_predictions

# Some Quick postprocessing
test['prediction'] = test.apply(TLAST_postprocess, axis=1)
test.head(10)

In [None]:
# Save file
test[['tweet_id', 'prediction']].to_csv("submission_bbbilou+heuristic1.csv", index=False)