In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertModel  # Changed to Fast tokenizer
from sklearn.model_selection import train_test_split
import numpy as np
from tqdm import tqdm
import os

# Define constants
MAX_LEN = 128
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 2e-5
BERT_MODEL = 'bert-base-cased'  # Using cased variant as NER is case-sensitive
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CoNLL-2003 has these entity types
tag2idx = {
    'O': 0,
    'B-PER': 1, 'I-PER': 2,
    'B-ORG': 3, 'I-ORG': 4,
    'B-LOC': 5, 'I-LOC': 6,
    'B-MISC': 7, 'I-MISC': 8
}
idx2tag = {v: k for k, v in tag2idx.items()}

class CoNLLDataset(Dataset):
    def __init__(self, texts, tags, tokenizer, max_len):
        self.texts = texts
        self.tags = tags
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        words = self.texts[idx]  # Already a list of words
        tag_list = self.tags[idx]
        
        # For debugging
        if idx == 0:
            print(f"Processing text: {' '.join(words[:10])}...")
            print(f"Words length: {len(words)}, Tags length: {len(tag_list)}")
        
        # Tokenize input text - words are already split
        encoding = self.tokenizer(
            words,  # Already a list of words, no need for split()
            is_split_into_words=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Get word_ids for alignment
        word_ids = encoding.word_ids(batch_index=0)
        
        # Align tags with wordpiece tokens
        aligned_labels = []
        previous_word_idx = None
        
        for word_idx in word_ids:
            # Special tokens have word_id set to None
            if word_idx is None:
                aligned_labels.append(-100)  # -100 is ignored by PyTorch loss function
            # First token of a word
            elif word_idx != previous_word_idx:
                if word_idx < len(tag_list):
                    aligned_labels.append(tag_list[word_idx])
                else:
                    # Handle case where word_idx is out of range
                    aligned_labels.append(0)  # Default to 'O'
            # Subsequent subword tokens - use same label as first token
            else:
                if word_idx < len(tag_list):
                    tag_val = tag_list[word_idx]
                    # For B- tags, convert to I- for subwords
                    if tag_val % 2 == 1 and tag_val > 0:  # if B- tag (odd value)
                        aligned_labels.append(tag_val + 1)  # Convert to I- tag
                    else:
                        aligned_labels.append(tag_val)
                else:
                    aligned_labels.append(0)
            
            previous_word_idx = word_idx
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(aligned_labels, dtype=torch.long)
        }

In [None]:
def load_conll2003():
    # Use the actual path to your CoNLL-2003 data
    file_path = "/kaggle/input/datasetnernew/eng.train"
    sentences = []
    tags_list = []
    
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            sentence = []
            sentence_tags = []
            
            for line in f:
                line = line.strip()
                
                # Empty line indicates end of sentence
                if not line:
                    if sentence:
                        sentences.append([word.lower() for word in sentence])   # Make a copy
                        tags_list.append(sentence_tags.copy())  # Make a copy
                        sentence.clear()
                        sentence_tags.clear()
                    continue
                
                # Skip document separator
                if line.startswith('-DOCSTART-'):
                    continue
                
                # Parse the token and its NER tag
                parts = line.split()
                if len(parts) >= 4:  # CoNLL-2003 has token, POS, chunk, NER format
                    token = parts[0]
                    ner_tag = parts[3]
                    
                    sentence.append(token)
                    sentence_tags.append(tag2idx.get(ner_tag, 0))  # Default to 'O' if tag not found
                
                # Print some sample data for debugging
                if len(sentences) == 0 and len(sentence) < 10:
                    print(f"Sample token: {token}, tag: {ner_tag}")
            
            # Add the last sentence if the file doesn't end with an empty line
            if sentence:
                sentences.append(sentence.copy())
                tags_list.append(sentence_tags.copy())
                
        print(f"Loaded {len(sentences)} sentences")
        if len(sentences) > 0:
            print(f"First sentence has {len(sentences[0])} tokens")
            
        return sentences, tags_list
        
    except Exception as e:
        print(f"Error loading data: {e}")
        # Return empty lists in case of error
        return [], []

In [3]:
class BERTSeq2SeqForNER(nn.Module):
    def __init__(self, bert_model_name, num_labels):
        super(BERTSeq2SeqForNER, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        self.num_labels = num_labels
        
        # Encoder (BERT)
        # The BERT model will encode the input text
        
        # Decoder
        # Simple linear layer for token classification
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        
        # Additional seq2seq components
        self.lstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=self.bert.config.hidden_size // 2,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )
    
    def forward(self, input_ids, attention_mask, labels=None):
        # Get BERT embeddings
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        
        # Process through LSTM
        lstm_output, _ = self.lstm(sequence_output)
        
        # Apply dropout
        lstm_output = self.dropout(lstm_output)
        
        # Apply classifier
        logits = self.classifier(lstm_output)
        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Only keep active parts of the loss
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)
            active_labels = torch.where(
                active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
            )
            loss = loss_fct(active_logits, active_labels)
        
        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}


In [4]:
def train_model():
    # Load dataset
    print("Loading CoNLL-2003 dataset...")
    sentences, tags_list = load_conll2003()
    
    if not sentences:
        print("Failed to load data. Exiting.")
        return
    
    # Split data
    train_sentences, val_sentences, train_tags, val_tags = train_test_split(sentences, tags_list, test_size=0.1)
    
    print(f"Train set: {len(train_sentences)} sentences")
    print(f"Validation set: {len(val_sentences)} sentences")
    
    # Initialize tokenizer - Use Fast version
    tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL)
    
    # Create datasets
    train_dataset = CoNLLDataset(train_sentences, train_tags, tokenizer, MAX_LEN)
    val_dataset = CoNLLDataset(val_sentences, val_tags, tokenizer, MAX_LEN)
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
    
    # Initialize model
    model = BERTSeq2SeqForNER(BERT_MODEL, len(tag2idx))
    model.to(DEVICE)
    
    # Initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    total_steps = len(train_loader) * EPOCHS
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=total_steps)
    
    # Training loop
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        
        for batch in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{EPOCHS}'):
            # Move batch to device
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs['loss']
            
            # Backward pass
            loss.backward()
            
            # Update weights
            optimizer.step()
            scheduler.step()
            
            total_loss += loss.item()
        
        # Calculate average loss
        avg_train_loss = total_loss / len(train_loader)
        print(f'Average training loss: {avg_train_loss}')
        
        # Evaluation
        model.eval()
        eval_loss = 0
        predictions = []
        true_labels = []
        
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Evaluating'):
                input_ids = batch['input_ids'].to(DEVICE)
                attention_mask = batch['attention_mask'].to(DEVICE)
                labels = batch['labels'].to(DEVICE)
                
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs['loss']
                logits = outputs['logits']
                
                eval_loss += loss.item()
                
                # Convert logits to predictions
                preds = torch.argmax(logits, dim=2)
                
                # Move predictions and labels to CPU
                preds = preds.detach().cpu().numpy()
                labels_np = labels.detach().cpu().numpy()
                
                # Store predictions and true labels
                for i, p in enumerate(preds):
                    label = labels_np[i]
                    mask = label != -100
                    predictions.append(p[mask])
                    true_labels.append(label[mask])
        
        # Calculate average validation loss
        avg_val_loss = eval_loss / len(val_loader)
        print(f'Validation loss: {avg_val_loss}')
        
        # Calculate and print metrics
        flat_predictions = [p for sublist in predictions for p in sublist]
        flat_true_labels = [l for sublist in true_labels for l in sublist]
        
        # Calculate accuracy
        accuracy = sum(p == t for p, t in zip(flat_predictions, flat_true_labels)) / len(flat_true_labels)
        print(f'Accuracy: {accuracy:.4f}')
    
    # Save the model
    torch.save(model.state_dict(), 'bert_seq2seq_ner.pt')
    print("Model saved!")

In [5]:
def predict(text):
    # Load the model
    model = BERTSeq2SeqForNER(BERT_MODEL, len(tag2idx))
    model.load_state_dict(torch.load('bert_seq2seq_ner.pt'))
    model.to(DEVICE)
    model.eval()
    
    # Initialize tokenizer - Use Fast version
    tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL)
    
    # Tokenize input
    words = text.lower().split()
    inputs = tokenizer(
        words,
        is_split_into_words=True,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=MAX_LEN
    )
    
    # Move to device
    input_ids = inputs['input_ids'].to(DEVICE)
    attention_mask = inputs['attention_mask'].to(DEVICE)
    
    # Get predictions
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs['logits']
        predictions = torch.argmax(logits, dim=2).cpu().numpy()[0]
    
    # Get word to token mapping
    word_ids = inputs.word_ids(batch_index=0)
    
    # Map predictions to words
    word_predictions = []
    prev_word_idx = None
    
    for token_idx, word_idx in enumerate(word_ids):
        if word_idx is None or word_idx == prev_word_idx:
            continue
        
        word = words[word_idx]
        tag_idx = predictions[token_idx]
        tag = idx2tag.get(tag_idx, "O")
        
        word_predictions.append((word, tag))
        prev_word_idx = word_idx
    
    return word_predictions


In [8]:
if __name__ == "__main__":
    train_model()
    
    # Example usage
    sample_text = "Apple Inc. is planning to open a new store in Berlin, Germany next year."
    predictions = predict(sample_text)
    
    for word, tag in predictions:
        print(f"{word}: {tag}")

Loading CoNLL-2003 dataset...
Sample token: EU, tag: B-ORG
Sample token: rejects, tag: O
Sample token: German, tag: B-MISC
Sample token: call, tag: O
Sample token: to, tag: O
Sample token: boycott, tag: O
Sample token: British, tag: B-MISC
Sample token: lamb, tag: O
Sample token: ., tag: O
Loaded 14041 sentences
First sentence has 9 tokens
Train set: 12636 sentences
Validation set: 1405 sentences


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Epoch 1/5:  87%|████████▋ | 686/790 [04:42<00:43,  2.38it/s]

Processing text: thomas bjorn ( denmark ) , fernando roca ( spain...
Words length: 13, Tags length: 13


Epoch 1/5: 100%|██████████| 790/790 [05:26<00:00,  2.42it/s]


Average training loss: 0.4715768903305259


Evaluating:   1%|          | 1/88 [00:00<00:16,  5.14it/s]

Processing text: the most deflating double fault came when oncins was serving...
Words length: 22, Tags length: 22


Evaluating: 100%|██████████| 88/88 [00:12<00:00,  7.02it/s]


Validation loss: 0.13998553008687767
Accuracy: 0.9624


Epoch 2/5:  60%|█████▉    | 472/790 [03:18<02:13,  2.38it/s]

Processing text: thomas bjorn ( denmark ) , fernando roca ( spain...
Words length: 13, Tags length: 13


Epoch 2/5: 100%|██████████| 790/790 [05:32<00:00,  2.38it/s]


Average training loss: 0.10554427271473069


Evaluating:   1%|          | 1/88 [00:00<00:12,  7.14it/s]

Processing text: the most deflating double fault came when oncins was serving...
Words length: 22, Tags length: 22


Evaluating: 100%|██████████| 88/88 [00:12<00:00,  7.01it/s]


Validation loss: 0.09801412815101106
Accuracy: 0.9741


Epoch 3/5:  15%|█▌        | 119/790 [00:50<04:44,  2.36it/s]

Processing text: thomas bjorn ( denmark ) , fernando roca ( spain...
Words length: 13, Tags length: 13


Epoch 3/5: 100%|██████████| 790/790 [05:31<00:00,  2.38it/s]


Average training loss: 0.055966960808521586


Evaluating:   1%|          | 1/88 [00:00<00:11,  7.25it/s]

Processing text: the most deflating double fault came when oncins was serving...
Words length: 22, Tags length: 22


Evaluating: 100%|██████████| 88/88 [00:12<00:00,  7.03it/s]


Validation loss: 0.08917587000178173
Accuracy: 0.9766


Epoch 4/5:  35%|███▍      | 276/790 [01:56<03:37,  2.37it/s]

Processing text: thomas bjorn ( denmark ) , fernando roca ( spain...
Words length: 13, Tags length: 13


Epoch 4/5: 100%|██████████| 790/790 [05:31<00:00,  2.38it/s]


Average training loss: 0.034365553230554145


Evaluating:   1%|          | 1/88 [00:00<00:12,  7.06it/s]

Processing text: the most deflating double fault came when oncins was serving...
Words length: 22, Tags length: 22


Evaluating: 100%|██████████| 88/88 [00:12<00:00,  7.04it/s]


Validation loss: 0.08490867409008471
Accuracy: 0.9798


Epoch 5/5:  32%|███▏      | 254/790 [01:46<03:45,  2.38it/s]

Processing text: thomas bjorn ( denmark ) , fernando roca ( spain...
Words length: 13, Tags length: 13


Epoch 5/5: 100%|██████████| 790/790 [05:31<00:00,  2.38it/s]


Average training loss: 0.024484963720381447


Evaluating:   1%|          | 1/88 [00:00<00:12,  7.20it/s]

Processing text: the most deflating double fault came when oncins was serving...
Words length: 22, Tags length: 22


Evaluating: 100%|██████████| 88/88 [00:12<00:00,  7.03it/s]


Validation loss: 0.08469250687366267
Accuracy: 0.9801
Model saved!


  model.load_state_dict(torch.load('bert_seq2seq_ner.pt'))


Apple: B-ORG
Inc.: I-ORG
is: O
planning: O
to: O
open: O
a: O
new: O
store: O
in: O
Berlin,: O
Germany: O
next: O
year.: O
