In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
from TorchCRF import CRF
import numpy as np

In [28]:

# Sample Data (Words and Entity Labels)
data = [
    (["John", "lives", "in", "New", "York"], ["PER", "O", "O", "LOC", "LOC"]),
    (["Google", "is", "a", "tech", "company"], ["ORG", "O", "O", "O", "O"]),
    (["Elon", "Musk", "founded", "SpaceX"], ["PER", "PER", "O", "ORG"]),
    (["Microsoft", "was", "founded", "by", "Bill", "Gates"], ["ORG", "O", "O", "O", "PER", "PER"]),
    (["Facebook", "headquarters", "is", "in", "California"], ["ORG", "O", "O", "O", "LOC"]),
    (["Jeff", "Bezos", "started", "Amazon"], ["PER", "PER", "O", "ORG"]),
    (["Paris", "is", "a", "beautiful", "city"], ["LOC", "O", "O", "O", "O"]),
    (["Apple", "launched", "a", "new", "iPhone"], ["ORG", "O", "O", "O", "O"]),
    (["Tesla", "is", "leading", "the", "EV", "market"], ["ORG", "O", "O", "O", "O", "O"]),
    (["Mark", "Zuckerberg", "founded", "Meta"], ["PER", "PER", "O", "ORG"]),
    (["The", "Eiffel", "Tower", "is", "in", "France"], ["O", "LOC", "LOC", "O", "O", "LOC"]),
    (["Cristiano", "Ronaldo", "plays", "for", "Al-Nassr"], ["PER", "PER", "O", "O", "ORG"]),
    (["Amazon", "is", "one", "of", "the", "biggest", "companies"], ["ORG", "O", "O", "O", "O", "O", "O"]),
    (["Tokyo", "is", "a", "major", "city", "in", "Japan"], ["LOC", "O", "O", "O", "O", "O", "LOC"]),
    (["Sundar", "Pichai", "is", "the", "CEO", "of", "Google"], ["PER", "PER", "O", "O", "O", "O", "ORG"]),
    (["The", "Amazon", "River", "flows", "through", "South", "America"], ["O", "LOC", "LOC", "O", "O", "LOC", "LOC"]),
    (["IBM", "was", "a", "leader", "in", "computing"], ["ORG", "O", "O", "O", "O", "O"]),
    (["The", "Statue", "of", "Liberty", "is", "in", "New", "York"], ["O", "LOC", "O", "LOC", "O", "O", "LOC", "LOC"]),
    (["Taylor", "Swift", "won", "multiple", "Grammy", "Awards"], ["PER", "PER", "O", "O", "O", "O"]),
    (["The", "Great", "Wall", "of", "China", "is", "a", "historic", "site"], ["O", "LOC", "LOC", "O", "LOC", "O", "O", "O", "O"]),
]

# Mapping Words and Labels to IDs
word2idx = {"<SOS>": 0, "<EOS>": 1, "UNK": 2} 
tag2idx = {"<SOS>": 0, "<EOS>": 1, "O": 2, "ORG": 3, "LOC": 4,"UNK":5}


# Indexing words
for sentence, tags in data:
    for word in sentence:
        if word not in word2idx:
            word2idx[word] = len(word2idx)
    for tag in tags:
        if tag not in tag2idx:
            tag2idx[tag] = len(tag2idx)
            
# Indexing tags
idx2tag = {v: k for k, v in tag2idx.items()}


In [29]:
def l1_regularization(model, l1_lambda=0.001):
    l1_norm = 0
    for param in model.parameters():
        l1_norm += param.abs().sum()  # Sum of absolute values of parameters
    return l1_lambda * l1_norm

In [30]:
def prepare_sequence(seq, mapping):
    # Add <SOS> at the start and <EOS> at the end of each sequence
    seq = ["<SOS>"] + seq + ["<EOS>"]
    return torch.tensor([mapping.get(w, mapping["UNK"]) for w in seq], dtype=torch.long)

In [31]:

train_data = [(prepare_sequence(sentence, word2idx), prepare_sequence(tags, tag2idx)) for sentence, tags in data]


In [32]:
# Early Stopping
def early_stopping(patience, validation_losses):
    if len(validation_losses) < patience:
        return False
    return np.argmin(validation_losses[-patience:]) == 0


In [33]:
class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, tagset_size, embedding_dim=32, hidden_dim=64, dropout=0.5):
        super(BiLSTM_CRF, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, tagset_size)
        self.crf = CRF(tagset_size)

    def forward(self, sentences, tags=None):
        # Embedding and LSTM
        embeds = self.embedding(sentences)
        lstm_out, _ = self.lstm(embeds)
        lstm_out = self.dropout(lstm_out)
        emissions = self.fc(lstm_out)

        # If tags are provided, calculate loss (during training)
        if tags is not None:
            # Masking out the <SOS> and <EOS> positions from the loss calculation
            mask = sentences != 0  # Mask where sentences are not padded
            mask &= sentences != 1  # Also mask out the <EOS> token (index 1)

            loss = -self.crf(emissions, tags, mask=mask)
            return loss
        else:
            # During inference, return the best tag sequence (Viterbi decoding)
            mask = sentences != 0  # Mask out padding
            return self.crf.viterbi_decode(emissions, mask=mask)


In [35]:

# Initialize Model, Optimizer, and Loss
model = BiLSTM_CRF(len(word2idx), len(tag2idx))
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001) 

patience = 5 
validation_losses = []  
# Training Loop
for epoch in range(300):
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0

    for words, tags in train_data:
        words, tags = words.unsqueeze(0), tags.unsqueeze(0)  # Add batch dimension
        optimizer.zero_grad()
        
        loss = model(words, tags)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Prediction step
        predicted_tags = model(words)
        predicted_tags = predicted_tags[0]  # Get the first sentence's predictions
        
        # Exclude SOS and EOS from accuracy calculation
        predicted_tags = predicted_tags[1:-1]  # Remove <SOS> and <EOS> tokens
        tags = tags.squeeze(0)[1:-1]  # Remove <SOS> and <EOS> tokens

        correct_predictions += sum(p == t for p, t in zip(predicted_tags, tags))    
        total_predictions += tags.nelement()  # Total number of tags

    accuracy = correct_predictions / total_predictions
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss:.4f}, Accuracy: {accuracy:.4f}")


Epoch 0, Loss: 254.8513, Accuracy: 0.2564
Epoch 20, Loss: -38.6287, Accuracy: 0.7778
Epoch 40, Loss: -265.2643, Accuracy: 0.8120
Epoch 60, Loss: -513.0708, Accuracy: 0.8291
Epoch 80, Loss: -782.4031, Accuracy: 0.8291
Epoch 100, Loss: -1018.9456, Accuracy: 0.8291
Epoch 120, Loss: -1257.9093, Accuracy: 0.8291
Epoch 140, Loss: -1474.0025, Accuracy: 0.8291
Epoch 160, Loss: -1734.5525, Accuracy: 0.8205
Epoch 180, Loss: -1932.1201, Accuracy: 0.8291
Epoch 200, Loss: -2160.5891, Accuracy: 0.8291
Epoch 220, Loss: -2527.5564, Accuracy: 0.8291
Epoch 240, Loss: -2710.4619, Accuracy: 0.8291
Epoch 260, Loss: -2946.6532, Accuracy: 0.8291
Epoch 280, Loss: -2997.5664, Accuracy: 0.8291


In [10]:
test_data = [
    (["Apple", "is", "a", "tech", "company"], ["ORG", "O", "O", "O", "O"]),
    (["Barack", "Obama", "was", "born", "in", "Hawaii"], ["PER", "PER", "O", "O", "O", "LOC"]),
    (["Microsoft", "acquired", "LinkedIn"], ["ORG", "O", "ORG"]),
]

# Convert Test Data to Index Form
test_data_idx = [(prepare_sequence(sentence, word2idx), prepare_sequence(tags, tag2idx)) for sentence, tags in test_data]

# Prepare for Inference on Test Data
model.eval()  # Switch the model to evaluation mode

correct = 0
total = 0

# Inference on Test Data (without tags, as we're predicting)
for sentence, true_tags in test_data:
    # Convert the sentence to indices
    test_seq = prepare_sequence(sentence, word2idx).unsqueeze(0)  # Add batch dimension
    
    # Perform inference
    with torch.no_grad():
        predicted_tags = model(test_seq)  # Get the predicted tag sequence
    
    # Convert predicted tags to their corresponding string labels
    predicted_tags = predicted_tags[0]  # Extract the predicted tags for this sentence
    predicted_tag_labels = [idx2tag[tag] for tag in predicted_tags]

    # Print the results for the entire sentence
    print("\nNamed Entity Recognition: Prediction Results:")
    for word, predicted_tag in zip(sentence, predicted_tag_labels):
        print(f"{word}: {predicted_tag}")
    
    # Calculate accuracy
    true_tags_idx = prepare_sequence(true_tags, tag2idx)
    correct += sum(p == t for p, t in zip(predicted_tags, true_tags_idx))
    total += len(true_tags)

accuracy = correct / total
print(f"\nAccuracy: {accuracy:.4f}")


Named Entity Recognition: Prediction Results:
Apple: LOC
is: ORG
a: O
tech: O
company: O

Named Entity Recognition: Prediction Results:
Barack: PER
Obama: PER
was: O
born: O
in: O
Hawaii: O

Named Entity Recognition: Prediction Results:
Microsoft: LOC
acquired: ORG
LinkedIn: O

Accuracy: 0.7857
