# Sentence Based Cross Attention BiLSTM Model

### Importing the Required Files and Libraries

In [23]:
import re
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.model_selection import KFold

# Paths to the NCBI disease corpus splits and to the saved model checkpoint
train_file = '../../Data/NCBItrainset_corpus.txt'
dev_file = '../../Data/NCBIdevelopset_corpus.txt'
# The separator after "Data" was missing here, so the path did not point
# inside the ../../Data/ directory used by the train and dev splits.
test_file = '../../Data/NCBItestset_corpus.txt'
model_name = '../../Models/BiLSTM_CrossAttention_NER_model.pth'

## Data Preparation

In [24]:
# Reading the dataset file
def read_dataset(file_path):
    """Return all lines of the text file at ``file_path`` as a list.

    Each returned line keeps its trailing newline, exactly as produced by
    ``readlines()``.
    """
    with open(file_path, "r") as handle:
        return handle.readlines()

def parse_dataset(lines):
    """Group consecutive non-blank lines into paragraphs.

    Every line is stripped of surrounding whitespace. A line that is empty
    after stripping ends the current paragraph; runs of blank lines never
    produce empty paragraphs.

    Returns a list of paragraphs, each a list of stripped line strings.
    """
    paragraphs = []
    current = []

    # The trailing empty sentinel flushes the last paragraph through the
    # same branch that handles blank separator lines.
    for text in [raw.strip() for raw in lines] + [""]:
        if text:
            current.append(text)
        elif current:
            paragraphs.append(current)
            current = []

    return paragraphs

def parse_paragraph(paragraph):
    """Split one corpus paragraph into its words and its annotations.

    Two kinds of lines are recognised:

    * text lines of the form ``<digits>|<one word char>|<text>``: the text is
      split on whitespace and appended to a single running word list;
    * annotation lines of the form
      ``<digits>\\t<start>\\t<end>\\t<mention>\\t<label>...``: stored as
      ``(start, end, mention, label)`` with ``start``/``end`` as ints.

    Any other line is ignored.

    Returns:
        ``(sentences, annotations)`` where ``sentences`` holds at most one
        word list (all text lines of the paragraph joined together) and
        ``annotations`` is the list of annotation tuples.

    Raises:
        IndexError: if an annotation line has fewer than five tab-separated
            fields.
    """
    sentences = []
    annotations = []
    sentence = []

    for line in paragraph:
        if re.match(r'^\d+\|\w\|', line):
            # maxsplit=2 keeps any '|' that occurs inside the text itself;
            # an unbounded split would silently drop everything after it.
            text = line.split('|', 2)[2]
            sentence.extend(text.split())

        elif re.match(r'^\d+\t\d+\t\d+\t', line):
            fields = line.split("\t")
            annotations.append((int(fields[1]), int(fields[2]), fields[3], fields[4]))

    if sentence:
        sentences.append(sentence)
    return sentences, annotations

In [25]:
# Data Labelling
def tag_annotations(sentences, annotations):
    """Assign an IO-style tag to every word of every sentence.

    Word character offsets are rebuilt by assuming the words of a sentence
    are separated by exactly one character. A word receives ``'I-' + label``
    when its span lies inside an annotation span or straddles either
    annotation boundary; all other words keep the tag ``'O'``. When several
    annotations touch the same word, the last one in ``annotations`` wins.

    Args:
        sentences: list of word lists.
        annotations: iterable of ``(start, end, mention, label)`` tuples.

    Returns:
        A list of ``(sentence, tags)`` tuples, one per input sentence.
    """
    tagged_sentences = []

    for sentence in sentences:
        # Character span (start, end) of each word in the reconstructed text
        spans = []
        cursor = 0
        for word in sentence:
            word_stop = cursor + len(word)
            spans.append((cursor, word_stop))
            cursor = word_stop + 1  # step over the single separator character

        tags = ['O' for _ in sentence]  # everything starts outside an entity

        for start, end, _mention, label in annotations:
            for idx, (word_start, word_end) in enumerate(spans):
                if (
                    (word_start >= start and word_end <= end)  # word inside annotation
                    or word_start < start < word_end           # annotation starts mid-word
                    or word_start < end < word_end             # annotation ends mid-word
                ):
                    tags[idx] = 'I-' + label

        tagged_sentences.append((sentence, tags))

    return tagged_sentences

def format_for_model(tagged_data):
    """Render tagged sentences as ``word<TAB>tag`` lines.

    Args:
        tagged_data: iterable of ``(words, tags)`` pairs.

    Returns:
        One string per sentence, with one ``word\\ttag`` pair per line and
        lines joined by ``'\\n'``.
    """
    return [
        '\n'.join(f'{word}\t{tag}' for word, tag in zip(words, tags))
        for words, tags in tagged_data
    ]

### Extracting the Data

In [26]:
# Load the training corpus and convert it to word/tag pairs.
# Pipeline: raw lines -> paragraphs -> (words, annotations) -> tagged sentences.
lines = read_dataset(train_file)
paragraphs = parse_dataset(lines)
all_tagged_data = []

# Each paragraph is parsed and tagged independently; the tagged sentences
# of all paragraphs are collected into one flat list.
for paragraph in paragraphs:
    sentences, annotations = parse_paragraph(paragraph)
    tagged_data = tag_annotations(sentences, annotations)
    all_tagged_data.extend(tagged_data)

# One "word<TAB>tag"-per-line string per tagged sentence
formatted_data = format_for_model(all_tagged_data)

# Output file for the formatted training data
Format_Data = "../../Formatted Data/BiLSTM_CrossAttention_DataPrep.txt"

# Save formatted data to a file; sentences are separated by a blank line
with open(Format_Data, 'w') as file:
    file.write('\n\n'.join(formatted_data))

In [27]:
# Load the development corpus and convert it to word/tag pairs.
# Pipeline: raw lines -> paragraphs -> (words, annotations) -> tagged sentences.
dev_lines = read_dataset(dev_file)
dev_paragraphs = parse_dataset(dev_lines)
dev_all_tagged_data = []

# Each paragraph is parsed and tagged independently; the tagged sentences
# of all paragraphs are collected into one flat list.
for dev_paragraph in dev_paragraphs:
    dev_sentences, dev_annotations = parse_paragraph(dev_paragraph)
    dev_tagged_data = tag_annotations(dev_sentences, dev_annotations)
    dev_all_tagged_data.extend(dev_tagged_data)

# One "word<TAB>tag"-per-line string per tagged sentence
dev_formatted_data = format_for_model(dev_all_tagged_data)

# Output file for the formatted development data
Dev_Format_Data = "../../Formatted Data/BiLSTM_CrossAttention_DevelopSet_DataPrep.txt"

# Save formatted data to a file; sentences are separated by a blank line
with open(Dev_Format_Data, 'w') as file:
    file.write('\n\n'.join(dev_formatted_data))

## Defining the Dataset Class

In [28]:
# Define the NER Dataset class
class LSTM_NER_Dataset(Dataset):
    """Dataset of (word indices, tag indices) pairs read from a formatted file.

    The file is expected to hold one ``word<TAB>tag`` pair per line, with
    sentences separated by a blank line. Words and tags are mapped to integer
    indices by two ``LabelEncoder`` instances fitted on this file only.

    NOTE(review): because every instance fits its own encoders, indices from
    two separately constructed datasets (e.g. train and dev) are not
    guaranteed to refer to the same words/tags; confirm how callers handle
    this.
    """

    def __init__(self, file_path):
        self.data = self.load_data(file_path)
        self.word_encoder = LabelEncoder()
        self.tag_encoder = LabelEncoder()
        self.sentences, self.tags = self.process_data()
        # Fit the vocabularies on every word / tag seen in this file
        self.word_encoder.fit([word for sent in self.sentences for word in sent])
        self.tag_encoder.fit([tag for tag_seq in self.tags for tag in tag_seq])

    def load_data(self, file_path):
        """Read the file and return its blank-line-separated entries."""
        with open(file_path, 'r') as file:
            data = file.read().split('\n\n')
        return data

    def process_data(self):
        """Split every entry of ``self.data`` into parallel word and tag lists.

        Entries that are empty after stripping (empty file, trailing blank
        lines, extra separators) are skipped instead of raising on the
        ``word, tag`` unpacking.

        Returns:
            ``(sentences, tags)``: two lists of equal length, holding the
            word list and the tag list of each sentence.

        Raises:
            ValueError: if a non-empty line does not contain exactly one tab.
        """
        sentences = []
        tags = []
        for entry in self.data:
            entry = entry.strip()
            if not entry:
                continue  # nothing to parse in a blank entry
            sentence = []
            tag_seq = []
            for line in entry.split('\n'):
                word, tag = line.split('\t')
                sentence.append(word)
                tag_seq.append(tag)
            sentences.append(sentence)
            tags.append(tag_seq)
        return sentences, tags

    def __len__(self):
        """Return the number of sentences."""
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return the ``idx``-th sentence as two ``torch.long`` index tensors."""
        sentence = self.sentences[idx]
        tags = self.tags[idx]
        word_indices = self.word_encoder.transform(sentence)
        tag_indices = self.tag_encoder.transform(tags)
        return torch.tensor(word_indices, dtype=torch.long), torch.tensor(tag_indices, dtype=torch.long)

## Defining the Cross Attention and Model Class

In [None]:
# Cross Attention Class
class CrossAttention(nn.Module):
    """Dot-product attention with learned, bias-free query/key/value projections.

    The attention scores are the raw product of projected queries and
    projected keys (no scaling factor is applied) and are normalised with a
    softmax over the last dimension.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Attribute names and creation order are kept as-is: they define the
        # state_dict keys and the order in which weights are initialised.
        self.query = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value):
        """Return the attention-weighted combination of the projected values.

        Inputs are presumably (batch, seq_len, hidden_dim) tensors - TODO
        confirm against callers; only the last two dimensions are used for
        the matrix products.
        """
        q = self.query(query)
        k = self.key(key)
        v = self.value(value)

        # Similarity of every query position with every key position
        weights = self.softmax(torch.matmul(q, k.transpose(-1, -2)))
        return torch.matmul(weights, v)

In [None]:
# Model Class (Defining the Layers and the data flow through these layers)
class BiLSTM_CrossAttention_NER_Model(nn.Module):
    """Token tagger: embedding -> BiLSTM -> attention -> linear tag scores.

    Batch normalisation is applied to the embeddings and to the attention
    output, and the same dropout module is applied after the embedding, the
    BiLSTM and the attention stages.

    Args:
        vocab_size: number of rows in the embedding table.
        tagset_size: number of output tag classes.
        embedding_dim: size of each word embedding.
        hidden_dim: hidden size of each LSTM direction (the BiLSTM output
            therefore has ``hidden_dim * 2`` features).
        dropout_prob: probability used by the shared ``nn.Dropout`` module.
    """

    def __init__(self, vocab_size, tagset_size, embedding_dim=128, hidden_dim=128, dropout_prob=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.batch_norm_emb = nn.BatchNorm1d(embedding_dim)
        # nn.LSTM's own ``dropout`` argument only acts between stacked
        # layers; with the single layer used here it did nothing except emit
        # a UserWarning, so it is no longer passed. Regularisation is done by
        # ``self.dropout`` in ``forward``.
        self.bilstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.cross_attention = CrossAttention(hidden_dim * 2)
        self.batch_norm_att = nn.BatchNorm1d(hidden_dim * 2)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_dim * 2, tagset_size)

    def forward(self, x):
        """Return per-token tag scores for a batch of word-index sequences.

        ``x`` is presumably a (batch, seq_len) LongTensor - TODO confirm
        against the training loop; the output then has one ``tagset_size``
        score vector per token.
        """
        emb = self.embedding(x)
        # BatchNorm1d normalises dimension 1, so the feature dimension is
        # moved there and back.
        emb = self.batch_norm_emb(emb.transpose(1, 2)).transpose(1, 2)
        emb = self.dropout(emb)

        bilstm_out, _ = self.bilstm(emb)
        bilstm_out = self.dropout(bilstm_out)

        # Query, key and value are all the BiLSTM output, i.e. the sequence
        # attends to itself.
        cross_att_out = self.cross_attention(bilstm_out, bilstm_out, bilstm_out)
        cross_att_out = self.batch_norm_att(cross_att_out.transpose(1, 2)).transpose(1, 2)
        cross_att_out = self.dropout(cross_att_out)

        tag_space = self.fc(cross_att_out)
        return tag_space

### Graph Plotting Function

In [None]:
# Function for plotting (to be used to visualize the training loss and validation loss)
# Used to check whether the model is underfitting or overfitting
def graph_plot(title, x_label, y_label, x_data, y_data, z_data, color = 'blue', linestyle = '-',
               y_legend = 'Training Loss', z_legend = 'Validation Loss'):
    """Plot two curves over the same x values and save the figure to disk.

    Args:
        title: figure title.
        x_label: label of the x axis.
        y_label: label of the y axis.
        x_data: shared x values of both curves.
        y_data: y values of the first curve (drawn with ``color``/``linestyle``).
        z_data: y values of the second curve (always a solid red line).
        color: colour of the first curve.
        linestyle: line style of the first curve.
        y_legend: legend entry of the first curve.
        z_legend: legend entry of the second curve.

    The figure is written to ``../../Graphs/aBiLSTM_CrossAttention.png``.
    """
    # Both curves carry a label; without labels ``plt.legend()`` has nothing
    # to show and only produces a warning.
    plt.plot(x_data, y_data, color = color, linestyle = linestyle, label = y_legend)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    plt.plot(x_data, z_data, color = 'r', linestyle = '-', label = z_legend)
    plt.legend()
    plt.savefig("../../Graphs/aBiLSTM_CrossAttention.png", bbox_inches = 'tight')

In [None]:
# Saving the model: only the parameters (state_dict) are written to the
# checkpoint path held in `model_name`.
# NOTE(review): `model` is not defined in the visible cells - confirm the
# training cell that creates it runs before this one.
torch.save(model.state_dict(), model_name)