**Training the model:**

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import pickle
from gensim.models import KeyedVectors
from sklearn.metrics import classification_report
from datasets import load_dataset
from torch.utils.data import TensorDataset
import gensim.downloader as api


class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_dim, output_dim, n_layers=1, dropout=0.3):
        super(RNNModel, self).__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            dropout=dropout,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, input_vectors):
        rnn_out, _ = self.rnn(input_vectors)
        output = self.fc(rnn_out)
        return output


def prepare_data(dataset, word2vec_model, label_encoding, max_len=128):
    input_vectors = []
    labels = []
    
    for data in dataset:
        tokens = data['tokens']
        ner_tags = data['ner_tags']
        word_vectors = []
        for token in tokens:
            if token in word2vec_model:
                word_vectors.append(word2vec_model[token])
            else:
                word_vectors.append(np.zeros(word2vec_model.vector_size))
        
        if len(word_vectors) > max_len:
            word_vectors = word_vectors[:max_len]
        else:
            pad_length = max_len - len(word_vectors)
            word_vectors.extend([np.zeros(word2vec_model.vector_size)] * pad_length)
        
        input_vectors.append(np.array(word_vectors))
        
        numerical_tags = [label_encoding.get(tag, -1) for tag in ner_tags]
        if len(numerical_tags) > max_len:
            numerical_tags = numerical_tags[:max_len]
        else:
            numerical_tags += [-1] * (max_len - len(numerical_tags))  
        
        labels.append(torch.tensor(numerical_tags))
    
    input_vectors = torch.tensor(input_vectors, dtype=torch.float32) 
    labels = torch.stack(labels)
    
    return TensorDataset(input_vectors, labels) 

def train_model(model, train_loader, num_epochs=10, lr=0.001):
    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for input_vectors, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(input_vectors)
            logits_flat = outputs.view(-1, outputs.shape[-1])
            labels_flat = labels.view(-1)
            loss = criterion(logits_flat, labels_flat)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}")

word2vec_model = api.load('word2vec-google-news-300')

dataset = load_dataset("surrey-nlp/PLOD-CW")
train_dataset = dataset['train']

label_encoding = {"B-O": 0, "B-AC": 1, "B-LF": 2, "I-LF": 3}

train_data = prepare_data(train_dataset, word2vec_model, label_encoding)

train_loader = data.DataLoader(train_data, batch_size=16, shuffle=True)

input_size = word2vec_model.vector_size
hidden_dim = 128
output_dim = len(label_encoding)

model = RNNModel(input_size, hidden_dim, output_dim)
train_model(model, train_loader, num_epochs=10)

  input_vectors = torch.tensor(input_vectors, dtype=torch.float32)


Epoch 1, Loss: 43.2709
Epoch 2, Loss: 32.0730
Epoch 3, Loss: 29.2046
Epoch 4, Loss: 27.7990
Epoch 5, Loss: 26.2266
Epoch 6, Loss: 25.8519
Epoch 7, Loss: 25.1497
Epoch 8, Loss: 24.8367
Epoch 9, Loss: 23.7956
Epoch 10, Loss: 23.1173


**Saving the model:**

In [2]:
def save_model(model, word2vec_model, label_encoding, file_path):
    model_data = {
        'input_size': model.rnn.input_size,
        'hidden_dim': model.rnn.hidden_size,
        'output_dim': model.fc.out_features,
        'n_layers': model.rnn.num_layers,
        'dropout': model.rnn.dropout,
        'state_dict': model.state_dict(),
        'word2vec_model': word2vec_model,
        'label_encoding': label_encoding
    }
    with open(file_path, 'wb') as f:
        pickle.dump(model_data, f)

In [3]:
save_model(model, word2vec_model, label_encoding, 'model.pkl')