# Sentence-Based Cross-Attention BiLSTM Model

### Importing the Required Files and Libraries

In [41]:
import re
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.model_selection import KFold
import numpy as np
import pandas as pd
from torch.cuda.amp import autocast

# Corpus paths, saved-model path and the out-of-vocabulary token
train_file = 'NCBItrainset_corpus.txt'
dev_file = 'NCBIdevelopset_corpus.txt'
# NOTE(review): assumes the test corpus lives in a 'Data' directory two levels up,
# while the train/dev paths above are relative to the working directory — confirm.
test_file = '../../Data/NCBItestset_corpus.txt'
model_name = 'BiLSTM_CrossAttention_NER_model.pth'
unknown_token = "<UNK>"  # replaces any word missing from the training vocabulary

## Data Preparation

In [42]:
# Reading the dataset file
def read_dataset(file_path, encoding="utf-8"):
    """Return every line of the text file at *file_path*, line endings included.

    Args:
        file_path: path of the corpus file to read.
        encoding: text encoding used to decode the file. Defaults to UTF-8
            rather than the platform-dependent locale default, so the same
            corpus decodes identically on every operating system.

    Returns:
        list[str] exactly as produced by ``readlines()``.
    """
    with open(file_path, "r", encoding=encoding) as file:
        return file.readlines()

def parse_dataset(lines):
    """Group raw lines into blank-line-separated blocks.

    Every line is stripped of surrounding whitespace; a line that is empty
    after stripping closes the current block. Runs of blank lines never
    produce empty blocks, and a trailing block without a final blank line
    is still returned.

    Args:
        lines: iterable of raw text lines.

    Returns:
        list of blocks in input order, each a list of stripped, non-empty lines.
    """
    blocks = []
    current = []
    for raw in lines:
        text = raw.strip()
        if not text:
            if current:
                blocks.append(current)
            current = []
            continue
        current.append(text)
    if current:
        blocks.append(current)
    return blocks

def parse_paragraph(paragraph):
    """Split one document block into its token list and its entity annotations.

    Text lines have the form ``<digits>|<one word char>|<text>``; their
    whitespace-separated tokens are concatenated, in order, into a single
    "sentence". Annotation lines are tab-separated
    ``<digits>\\t<start>\\t<end>\\t<mention>\\t<label>...``.

    NOTE(review): tokens keep no record of the original spacing, while
    tag_annotations rebuilds character offsets assuming exactly one space
    between tokens — runs of whitespace in the raw text would shift offsets.

    Args:
        paragraph: list of stripped, non-empty lines for one document.

    Returns:
        ``(sentences, annotations)``: ``sentences`` holds at most one token
        list; ``annotations`` is a list of ``(start, end, mention, label)``
        tuples with integer character offsets.
    """
    sentences = []
    annotations = []
    sentence = []

    for line in paragraph:
        if re.match(r'^\d+\|\w\|', line):
            # maxsplit=2 keeps any '|' characters that occur inside the text itself
            sentence.extend(line.split('|', 2)[2].split())

        elif re.match(r'^\d+\t\d+\t\d+\t', line):
            fields = line.split("\t")
            annotations.append((int(fields[1]), int(fields[2]), fields[3], fields[4]))

    if sentence:
        sentences.append(sentence)
    return sentences, annotations

In [43]:
# Data Labelling
def tag_annotations(sentences, annotations):
    """Assign an IO-style tag to every token of every sentence.

    Character spans are rebuilt by laying the tokens out with a single space
    between them. A token receives ``'I-' + label`` when it lies inside an
    annotation's ``[start, end]`` span or when the span's start or end falls
    strictly inside the token; all other tokens stay ``'O'``. When several
    annotations hit the same token, the later one in ``annotations`` wins.

    Args:
        sentences: list of token lists.
        annotations: iterable of ``(start, end, mention, label)`` tuples.

    Returns:
        list of ``(tokens, tags)`` pairs, one per sentence, where ``tokens`` is
        the input list itself.
    """
    result = []
    for tokens in sentences:
        labels = ['O'] * len(tokens)

        # Character span of each token, assuming one separating space
        spans = []
        offset = 0
        for token in tokens:
            spans.append((offset, offset + len(token)))
            offset += len(token) + 1

        for ann_start, ann_end, _mention, ann_label in annotations:
            for idx, (tok_start, tok_end) in enumerate(spans):
                fully_inside = tok_start >= ann_start and tok_end <= ann_end
                boundary_cuts_token = (tok_start < ann_start < tok_end
                                       or tok_start < ann_end < tok_end)
                if fully_inside or boundary_cuts_token:
                    labels[idx] = 'I-' + ann_label

        result.append((tokens, labels))
    return result

def format_for_model(tagged_data):
    """Render tagged sentences as text blocks with one ``word<TAB>tag`` per line.

    Args:
        tagged_data: iterable of ``(words, tags)`` pairs.

    Returns:
        list with one string per sentence; lines are newline-joined with no
        trailing newline. Each pair is truncated to the shorter of ``words``
        and ``tags``, as ``zip`` does.
    """
    return [
        '\n'.join(f'{word}\t{tag}' for word, tag in zip(words, tags))
        for words, tags in tagged_data
    ]

### Extracting the Data

In [44]:
# Load and preprocess data
import os

lines = read_dataset(train_file)
paragraphs = parse_dataset(lines)
all_tagged_data = []

for paragraph in paragraphs:
    sentences, annotations = parse_paragraph(paragraph)
    tagged_data = tag_annotations(sentences, annotations)
    all_tagged_data.extend(tagged_data)

formatted_data = format_for_model(all_tagged_data)

Format_Data = "../../Formatted Data/BiLSTM_CrossAttention_DataPrep.txt"

# Save formatted data to a file. Create the output directory first so a fresh
# checkout does not fail with FileNotFoundError, and write UTF-8 explicitly so
# the output does not depend on the platform's default encoding.
os.makedirs(os.path.dirname(Format_Data), exist_ok=True)
with open(Format_Data, 'w', encoding='utf-8') as file:
    file.write('\n\n'.join(formatted_data))

FileNotFoundError: [Errno 2] No such file or directory: 'NCBItrainset_corpus.txt'

In [None]:
# Load and preprocess data
import os

dev_lines = read_dataset(dev_file)
dev_paragraphs = parse_dataset(dev_lines)
dev_all_tagged_data = []

for dev_paragraph in dev_paragraphs:
    dev_sentences, dev_annotations = parse_paragraph(dev_paragraph)
    dev_tagged_data = tag_annotations(dev_sentences, dev_annotations)
    dev_all_tagged_data.extend(dev_tagged_data)

dev_formatted_data = format_for_model(dev_all_tagged_data)

Dev_Format_Data = "../../Formatted Data/BiLSTM_CrossAttention_DevelopSet_DataPrep.txt"

# Save formatted data to a file. Create the output directory first so a fresh
# checkout does not fail with FileNotFoundError, and write UTF-8 explicitly so
# the output does not depend on the platform's default encoding.
os.makedirs(os.path.dirname(Dev_Format_Data), exist_ok=True)
with open(Dev_Format_Data, 'w', encoding='utf-8') as file:
    file.write('\n\n'.join(dev_formatted_data))

## Defining the Dataset Class

In [None]:
# Define the NER Dataset class
class NERDataset(Dataset):
    """Map-style dataset that turns token/tag sequences into index tensors.

    Args:
        sentences: list of token lists.
        tags: list of tag lists, aligned one-to-one with ``sentences``.
        word_encoder: fitted encoder exposing ``classes_`` and ``transform``
            (a scikit-learn ``LabelEncoder`` in this notebook).
        tag_encoder: fitted encoder for the tag strings.
        unknown_token: token whose index replaces any word absent from
            ``word_encoder.classes_``.
    """

    def __init__(self, sentences, tags, word_encoder, tag_encoder, unknown_token):
        self.sentences = sentences
        self.tags = tags
        self.word_encoder = word_encoder
        self.tag_encoder = tag_encoder
        self.unknown_token = unknown_token
        # Build the word -> index table once. A word's position in classes_ is
        # the index LabelEncoder.transform returns. The previous per-token
        # `word in classes_` test scanned the whole vocabulary array and added a
        # transform() call per token.
        self._word_to_index = {word: index for index, word in enumerate(word_encoder.classes_)}

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return ``(word_ids, tag_ids)`` as 1-D ``torch.long`` tensors for sample *idx*.

        Raises:
            ValueError: if a word is out of vocabulary and ``unknown_token``
                itself is not in ``word_encoder.classes_``.
        """
        sentence = self.sentences[idx]
        tags = self.tags[idx]
        lookup = self._word_to_index
        unknown_index = lookup.get(self.unknown_token)
        word_indices = []
        for word in sentence:
            index = lookup.get(word, unknown_index)
            if index is None:
                raise ValueError(f"{self.unknown_token!r} is not in word_encoder.classes_")
            word_indices.append(index)
        tag_indices = self.tag_encoder.transform(tags)
        return torch.tensor(word_indices, dtype=torch.long), torch.tensor(tag_indices, dtype=torch.long)

def collate_fn(batch):
    """Pad a batch of variable-length ``(word_ids, tag_ids)`` pairs to one length.

    Args:
        batch: sequence of ``(word_tensor, tag_tensor)`` 1-D tensor pairs.

    Returns:
        ``(words, tags)``, each of shape ``(batch, max_len)``. Words are padded
        with 0 and tags with -100, the ``ignore_index`` the training loss uses.

    NOTE(review): word indices presumably come from a LabelEncoder, so 0 is
    also the index of a real vocabulary entry — padded positions are
    indistinguishable from that word to the model. Confirm this is intended.
    """
    word_seqs, tag_seqs = zip(*batch)
    pad = torch.nn.utils.rnn.pad_sequence
    padded_words = pad(word_seqs, batch_first=True, padding_value=0)
    padded_tags = pad(tag_seqs, batch_first=True, padding_value=-100)
    return padded_words, padded_tags

## Defining the Cross-Attention and BiLSTM Model Classes

In [None]:
# Cross Attention Class
class CrossAttention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Query, key and value are passed separately, so the module can attend
    across two sequences. The model in this notebook passes the same BiLSTM
    output for all three.

    Args:
        hidden_dim: feature size of the inputs and of every projection.
    """

    def __init__(self, hidden_dim):
        super(CrossAttention, self).__init__()
        self.query = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        # 1/sqrt(d) keeps dot-product magnitudes independent of hidden_dim so the
        # softmax does not saturate into near one-hot weights (Vaswani et al., 2017).
        self.scale = hidden_dim ** -0.5

    def forward(self, query, key, value, key_padding_mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: tensor whose last dimension is ``hidden_dim``, presumably
                ``(batch, seq_q, hidden_dim)``.
            key, value: tensors with trailing size ``hidden_dim`` and the same
                sequence length ``seq_k`` as each other.
            key_padding_mask: optional bool tensor of shape ``(batch, seq_k)``
                (requires batch-first 3-D inputs); ``True`` marks key positions
                that receive zero attention. ``None`` attends to every position.

        Returns:
            tensor of shape ``query.shape[:-1] + (hidden_dim,)``.
        """
        query_proj = self.query(query)
        key_proj = self.key(key)
        value_proj = self.value(value)
        attention_scores = torch.matmul(query_proj, key_proj.transpose(-2, -1)) * self.scale
        if key_padding_mask is not None:
            # (batch, seq_k) -> (batch, 1, seq_k) broadcasts over every query position.
            # finfo.min instead of -inf avoids NaNs when a row is fully masked.
            attention_scores = attention_scores.masked_fill(
                key_padding_mask.unsqueeze(-2), torch.finfo(attention_scores.dtype).min
            )
        attention_weights = self.softmax(attention_scores)
        context = torch.matmul(attention_weights, value_proj)
        return context

In [None]:
# Model Class (Defining the Layers and the data flow through these layers)
class BiLSTM_CrossAttention_NER_Model(nn.Module):
    """Token classifier: embedding -> BiLSTM -> attention -> linear, with BatchNorm and dropout.

    Args:
        vocab_size: number of rows in the embedding table.
        tagset_size: number of output classes per token.
        embedding_dim: embedding width.
        hidden_dim: hidden size of each LSTM direction; the BiLSTM output and
            everything after it is ``2 * hidden_dim`` wide.
        dropout_prob: dropout rate shared by the embedding and attention outputs.

    NOTE(review): padded positions flow through both BatchNorm layers and the
    unmasked attention, so they affect batch statistics and attention weights.
    """

    def __init__(self, vocab_size, tagset_size, embedding_dim=128, hidden_dim=128, dropout_prob=0.3):
        super(BiLSTM_CrossAttention_NER_Model, self).__init__()
        bi_width = hidden_dim * 2  # forward + backward LSTM states concatenated
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.batch_norm_emb = nn.BatchNorm1d(embedding_dim)
        self.bilstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.cross_attention = CrossAttention(bi_width)
        self.batch_norm_att = nn.BatchNorm1d(bi_width)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(bi_width, tagset_size)

    @staticmethod
    def _channel_norm(norm, seq):
        """Apply a BatchNorm1d over the feature axis of a (batch, seq, features) tensor."""
        return norm(seq.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        """Map word indices ``(batch, seq_len)`` to logits ``(batch, seq_len, tagset_size)``."""
        hidden = self.dropout(self._channel_norm(self.batch_norm_emb, self.embedding(x)))
        hidden, _ = self.bilstm(hidden)
        hidden = self.cross_attention(hidden, hidden, hidden)
        hidden = self.dropout(self._channel_norm(self.batch_norm_att, hidden))
        return self.fc(hidden)

### Graph Plotting Function

In [None]:
# Function for plotting (to be used to visualize the training loss and validation loss)
# Used to tell whether the model is underfitting or overfitting
def graph_plot(title, x_label, y_label, x_data, y_data, z_data, color = 'blue', linestyle = '-',
               y_legend='Training Loss', z_legend='Validation Loss', z_color='red',
               save_path="../../Graphs/BiLSTM_CrossAttention.png"):
    """Draw two series over the same x values, label them, and save the figure.

    Args:
        title, x_label, y_label: figure title and axis labels.
        x_data: x values shared by both series; needs one entry per element
            of ``y_data`` and ``z_data``.
        y_data: first series, drawn with ``color`` and ``linestyle``.
        z_data: second series, drawn solid in ``z_color``.
        color, linestyle: style of the first series.
        y_legend, z_legend: legend entries for the two series.
        z_color: colour of the second series.
        save_path: file the figure is written to via ``plt.savefig``.

    Draws on matplotlib's current figure, so repeated calls overlay onto it.
    """
    plt.plot(x_data, y_data, color=color, linestyle=linestyle, label=y_legend)
    plt.plot(x_data, z_data, color=z_color, linestyle='-', label=z_legend)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    # legend() only lists artists that carry a label=, hence the labels above
    plt.legend()
    plt.savefig(save_path, bbox_inches='tight')

In [None]:
def _load_tagged_corpus(file_path):
    """Read one corpus file and return its ``(tokens, tags)`` pairs in file order."""
    tagged = []
    for block in parse_dataset(read_dataset(file_path)):
        block_sentences, block_annotations = parse_paragraph(block)
        tagged.extend(tag_annotations(block_sentences, block_annotations))
    return tagged


# Load and preprocess the training data
all_train_tagged_data = _load_tagged_corpus(train_file)
train_sentences = [sentence for sentence, _ in all_train_tagged_data]
train_tags = [tags for _, tags in all_train_tagged_data]

# Load and preprocess validation data
all_dev_tagged_data = _load_tagged_corpus(dev_file)
val_sentences = [sentence for sentence, _ in all_dev_tagged_data]
val_tags = [tags for _, tags in all_dev_tagged_data]

## Encoding the Data and Training the Model

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Encode words and tags. The word vocabulary comes from the training split only,
# plus the unknown token that stands in for out-of-vocabulary words.
word_encoder = LabelEncoder()
word_encoder.fit([token for sentence in train_sentences for token in sentence] + [unknown_token])
# NOTE(review): tags are also fitted on the training split only; a validation tag
# never seen in training would presumably make tag_encoder.transform raise.
tag_encoder = LabelEncoder()
tag_encoder.fit([label for sentence_tags in train_tags for label in sentence_tags])

train_dataset = NERDataset(train_sentences, train_tags, word_encoder, tag_encoder, unknown_token)
val_dataset = NERDataset(val_sentences, val_tags, word_encoder, tag_encoder, unknown_token)

# Only the training batches are shuffled; validation order stays fixed
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)


cuda


In [None]:

# Build the model, loss, optimiser, LR schedule and gradient scaler
model = BiLSTM_CrossAttention_NER_Model(len(word_encoder.classes_), len(tag_encoder.classes_)).to(device)
# Padded tag positions are filled with -100 by collate_fn and ignored by the loss
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
# Halve the learning rate when the validation loss has not improved for 3 epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)
# torch.cuda.amp only works on CUDA; on CPU autocast and GradScaler would warn
# and disable themselves, so mixed precision is switched off explicitly there.
use_amp = device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

print ("Starting Training")
loss_dic = {}        # epoch index -> mean per-batch training loss
valid_loss_dic = {}  # epoch index -> mean per-batch validation loss

for epoch in range(50):
    total_loss = 0
    total_valid_loss = 0

    # Training pass: dropout on, BatchNorm normalises with (and updates) batch statistics.
    model.train()
    # Batch tensors get their own names so the global sentence/tag lists stay intact
    for batch_sentences, batch_tags in train_dataloader:
        batch_sentences = batch_sentences.to(device)
        batch_tags = batch_tags.to(device)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(batch_sentences)
            # (batch, seq, tags) -> (batch * seq, tags) to match the flattened targets
            loss = criterion(outputs.reshape(-1, outputs.shape[-1]), batch_tags.reshape(-1))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_dataloader)

    # Validation pass: eval() disables dropout and makes BatchNorm use its running
    # statistics instead of updating them with validation batches.
    model.eval()
    with torch.no_grad():
        for batch_val_sentences, batch_val_tags in val_dataloader:
            batch_val_sentences = batch_val_sentences.to(device)
            batch_val_tags = batch_val_tags.to(device)

            with autocast(enabled=use_amp):
                val_outputs = model(batch_val_sentences)
                valid_loss = criterion(val_outputs.reshape(-1, val_outputs.shape[-1]),
                                       batch_val_tags.reshape(-1))
            total_valid_loss += valid_loss.item()

    avg_valid_loss = total_valid_loss / len(val_dataloader)

    # Plateau detection on the held-out loss, so the LR drops when generalisation stalls
    scheduler.step(avg_valid_loss)

    print(f"Epoch {epoch + 1}, Average Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_valid_loss:.4f}")
    loss_dic[epoch] = avg_train_loss
    valid_loss_dic[epoch] = avg_valid_loss

# One x value per recorded epoch, 1-based to match the printed epoch numbers
graph_plot('Training and Validation Loss', 'Epochs', 'Loss',
           [e + 1 for e in loss_dic], list(loss_dic.values()), list(valid_loss_dic.values()))


Starting Training


KeyboardInterrupt: 

In [None]:
# Persist only the learned weights (the state_dict) to the path in model_name.
# NOTE(review): the word/tag encoders are not saved alongside; reusing the model
# presumably requires refitting them identically — confirm.
torch.save(obj=model.state_dict(), f=model_name)