In [1]:
# Imports
import os
import time
import pickle
import math
from collections import Counter
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, classification_report

# Paths & device
DATA_DIR = '../../data/ner/'
MODEL_DIR = '../../models/ner/'
os.makedirs(MODEL_DIR, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
# Load arrays
Xw_train = np.load(os.path.join(DATA_DIR, 'Xw_train.npy'))
Xw_val = np.load(os.path.join(DATA_DIR, 'Xw_val.npy'))
Xp_train = np.load(os.path.join(DATA_DIR, 'Xp_train.npy'))
Xp_val = np.load(os.path.join(DATA_DIR, 'Xp_val.npy'))
Yt_train = np.load(os.path.join(DATA_DIR, 'Yt_train.npy'))
Yt_val = np.load(os.path.join(DATA_DIR, 'Yt_val.npy'))

# Load vocabularies
with open(os.path.join(DATA_DIR, 'word2idx.pkl'), 'rb') as f:
    word2idx = pickle.load(f)
with open(os.path.join(DATA_DIR, 'pos2idx.pkl'), 'rb') as f:
    pos2idx = pickle.load(f)
with open(os.path.join(DATA_DIR, 'tag2idx.pkl'), 'rb') as f:
    tag2idx = pickle.load(f)

# Create reverse mappings
idx2word = {v: k for k, v in word2idx.items()}
idx2pos = {v: k for k, v in pos2idx.items()}
idx2tag = {v: k for k, v in tag2idx.items()}

PAD_IDX = word2idx['<PAD>']
UNK_IDX = word2idx['<UNK>']
max_seq_len = Xw_train.shape[1]

print(f'Train/Val sentences: {len(Xw_train)} / {len(Xw_val)}')
print(f'Max sequence length: {max_seq_len}')
print(f'Vocab sizes - Words: {len(word2idx)}, POS: {len(pos2idx)}, Tags: {len(tag2idx)}')

Train/Val sentences: 38367 / 9592
Max sequence length: 35
Vocab sizes - Words: 35179, POS: 43, Tags: 18


In [3]:
# Dataset & DataLoader
class NERDataset(Dataset):
    def __init__(self, words, pos, tags):
        self.words = torch.LongTensor(words)
        self.pos = torch.LongTensor(pos)
        self.tags = torch.LongTensor(tags)
    
    def __len__(self):
        return len(self.words)
    
    def __getitem__(self, idx):
        return self.words[idx], self.pos[idx], self.tags[idx]

BATCH_SIZE = 32
train_dataset = NERDataset(Xw_train, Xp_train, Yt_train)
val_dataset = NERDataset(Xw_val, Xp_val, Yt_val)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f'Train batches: {len(train_loader)}, Val batches: {len(val_loader)}')

Train batches: 1199, Val batches: 300


In [4]:
# Native CRF Implementation
class CRF(nn.Module):
    """Native CRF implementation without external dependencies"""
    
    def __init__(self, num_tags, batch_first=True):
        super(CRF, self).__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        
        # Transition parameters: transition[i][j] = score of transitioning from tag i to tag j
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
        
        # Initialize transitions (don't allow transitions to PAD)
        self.transitions.data[PAD_IDX, :] = -10000  # No transitions from PAD
        self.transitions.data[:, PAD_IDX] = -10000  # No transitions to PAD
    
    def _compute_partition_function(self, emissions, mask):
        """Compute the partition function using forward algorithm"""
        batch_size, seq_length, num_tags = emissions.size()
        
        # Initialize forward variables
        forward_var = emissions[:, 0].clone()  # (batch_size, num_tags)
        
        for i in range(1, seq_length):
            # Broadcast forward_var and transitions for batch processing
            emit_score = emissions[:, i].unsqueeze(1)  # (batch_size, 1, num_tags)
            trans_score = self.transitions.unsqueeze(0)  # (1, num_tags, num_tags)
            next_tag_var = forward_var.unsqueeze(2) + trans_score + emit_score
            
            # Use logsumexp for numerical stability
            next_tag_var = torch.logsumexp(next_tag_var, dim=1)  # (batch_size, num_tags)
            
            # Apply mask
            forward_var = torch.where(mask[:, i].unsqueeze(1), next_tag_var, forward_var)
        
        # Sum over all possible ending tags
        terminal_var = torch.logsumexp(forward_var, dim=1)  # (batch_size,)
        return terminal_var
    
    def _compute_score(self, emissions, tags, mask):
        """Compute the score of a given tag sequence"""
        batch_size, seq_length = tags.size()
        
        # Compute emission scores
        emission_scores = torch.gather(emissions, 2, tags.unsqueeze(2)).squeeze(2)
        emission_scores = emission_scores * mask.float()
        emission_scores = emission_scores.sum(dim=1)  # (batch_size,)
        
        # Compute transition scores
        transition_scores = torch.zeros(batch_size, device=emissions.device)
        
        for i in range(seq_length - 1):
            curr_tags = tags[:, i]
            next_tags = tags[:, i + 1]
            
            # Get transition scores for valid positions
            valid_mask = mask[:, i + 1]
            trans_score = self.transitions[curr_tags, next_tags]
            transition_scores += trans_score * valid_mask.float()
        
        return emission_scores + transition_scores
    
    def forward(self, emissions, tags, mask=None):
        """Compute CRF negative log likelihood"""
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.bool)
        
        # Compute partition function (normalizer)
        partition = self._compute_partition_function(emissions, mask)
        
        # Compute score of the given sequence
        sequence_score = self._compute_score(emissions, tags, mask)
        
        # Return negative log likelihood
        return (partition - sequence_score).mean()
    
    def decode(self, emissions, mask=None):
        """Viterbi decoding to find the best sequence"""
        if mask is None:
            mask = torch.ones(emissions.size()[:2], dtype=torch.bool, device=emissions.device)
        
        batch_size, seq_length, num_tags = emissions.size()
        
        # Initialize
        viterbi_vars = emissions[:, 0].clone()  # (batch_size, num_tags)
        path_scores = []
        
        # Forward pass
        for i in range(1, seq_length):
            broadcast_vars = viterbi_vars.unsqueeze(2)  # (batch_size, num_tags, 1)
            broadcast_trans = self.transitions.unsqueeze(0)  # (1, num_tags, num_tags)
            next_tag_vars = broadcast_vars + broadcast_trans
            
            # Find best previous tags
            best_tag_scores, best_tags = torch.max(next_tag_vars, dim=1)
            path_scores.append(best_tags)
            
            # Add emission scores
            best_tag_scores += emissions[:, i]
            
            # Apply mask
            viterbi_vars = torch.where(mask[:, i].unsqueeze(1), best_tag_scores, viterbi_vars)
        
        # Backward pass to find best path
        best_paths = []
        
        for batch_idx in range(batch_size):
            # Find best final tag
            seq_len = mask[batch_idx].sum().item()
            if seq_len == 0:
                best_paths.append([])
                continue
                
            _, best_last_tag = torch.max(viterbi_vars[batch_idx], dim=0)
            best_path = [best_last_tag.item()]
            
            # Backtrack
            for i in range(len(path_scores) - 1, -1, -1):
                if i + 1 < seq_len:
                    best_last_tag = path_scores[i][batch_idx][best_last_tag]
                    best_path.append(best_last_tag.item())
            
            # Reverse to get correct order
            best_path.reverse()
            
            # Pad with zeros if necessary
            while len(best_path) < seq_len:
                best_path.append(0)
            
            best_paths.append(best_path[:seq_len])
        
        return best_paths

In [5]:
# BiLSTM-CRF Model
class BiLSTM_CRF_Enhanced(nn.Module):
    def __init__(self, config):
        super(BiLSTM_CRF_Enhanced, self).__init__()
        self.config = config
        self.pad_idx = PAD_IDX
        
        # Word embeddings
        self.word_embedding = nn.Embedding(
            config['vocab_size'], config['word_embed_dim'], padding_idx=self.pad_idx
        )
        
        # POS embeddings
        self.pos_embedding = nn.Embedding(
            config['pos_vocab_size'], config['pos_embed_dim']
        )
        
        # Character-level CNN for better OOV handling
        self.char_embed_dim = 30
        self.char_cnn = self._build_char_cnn()
        
        # BiLSTM layers
        total_embed_dim = (config['word_embed_dim'] + 
                          config['pos_embed_dim'] + 
                          config['char_embed_dim'])
        
        self.lstm = nn.LSTM(
            input_size=total_embed_dim,
            hidden_size=config['hidden_dim'],
            num_layers=config['num_layers'],
            batch_first=True,
            bidirectional=True,
            dropout=config['dropout'] if config['num_layers'] > 1 else 0
        )
        
        # Layer normalization
        self.layer_norm = nn.LayerNorm(config['hidden_dim'] * 2)
        
        # Dropout
        self.dropout = nn.Dropout(config['dropout'])
        
        # Linear layer for emissions
        self.hidden2tag = nn.Linear(
            config['hidden_dim'] * 2, config['tag_vocab_size']
        )
        
        # CRF layer
        self.crf = CRF(config['tag_vocab_size'], batch_first=True)
        
        # Initialize weights
        self._init_weights()
    
    def _enhanced_char_cnn(self):
        return nn.Sequential(
            nn.Conv1d(self.char_embed_dim, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveMaxPool1d(1),
            nn.Dropout(0.25)
    )
    
    def _get_char_features(self, words):
        """Extract character-level features"""
        batch_size, seq_len = words.size()
        char_features = torch.zeros(batch_size, seq_len, self.config['char_embed_dim'], 
                                   device=words.device)
        
        # Simple character features
        for i in range(batch_size):
            for j in range(seq_len):
                word_idx = words[i, j].item()
                if word_idx != self.pad_idx:
                    word = idx2word.get(word_idx, '<UNK>')
                    # Simple character encoding (first few characters)
                    for k, char in enumerate(word[:self.config['char_embed_dim']]):
                        char_features[i, j, k] = ord(char) % 128
        
        return char_features
    
    def _init_weights(self):
        """Initialize model weights"""
        # Initialize embeddings
        nn.init.uniform_(self.word_embedding.weight, -0.1, 0.1)
        nn.init.uniform_(self.pos_embedding.weight, -0.1, 0.1)
        
        # Initialize LSTM
        for name, param in self.lstm.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0)
        
        # Initialize linear layer
        nn.init.xavier_uniform_(self.hidden2tag.weight)
        nn.init.constant_(self.hidden2tag.bias, 0)
    
    def forward(self, words, pos, tags=None):
        batch_size, seq_len = words.size()
        
        # Create mask
        mask = (words != self.pad_idx)
        
        # Get embeddings
        word_embeds = self.word_embedding(words)
        pos_embeds = self.pos_embedding(pos)
        char_features = self._get_char_features(words)
        
        # Concatenate embeddings
        embeds = torch.cat([word_embeds, pos_embeds, char_features], dim=2)
        embeds = self.dropout(embeds)
        
        # BiLSTM
        lstm_out, _ = self.lstm(embeds)
        lstm_out = self.layer_norm(lstm_out)
        lstm_out = self.dropout(lstm_out)
        
        # Get emission scores
        emissions = self.hidden2tag(lstm_out)
        
        if tags is not None:
            # Training mode: return CRF loss
            loss = self.crf(emissions, tags, mask)
            return loss
        else:
            # Inference mode: return best sequence
            best_paths = self.crf.decode(emissions, mask)
            return best_paths

In [6]:
# Model Configuration and Initialization
CONFIG = {
    'vocab_size': len(word2idx),
    'pos_vocab_size': len(pos2idx),
    'tag_vocab_size': len(tag2idx),
    'word_embed_dim': 150,
    'pos_embed_dim': 25,
    'char_embed_dim': 50,
    'hidden_dim': 300,
    'num_layers': 2,
    'dropout': 0.3,
    'learning_rate': 0.002,
    'num_epochs': 15,
    'patience': 3,
    'max_grad_norm': 5.0
}

model = BiLSTM_CRF_Enhanced(CONFIG).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f'Total parameters: {total_params:,}')

Total parameters: 8,724,417


In [7]:
# Optimizer and Scheduler Setup
optimizer = optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=1e-4)

# Warmup + Cosine decay scheduler
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * CONFIG['num_epochs']
warmup_steps = int(0.1 * total_steps)

def lr_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [8]:
# Training Functions
def train_epoch(model, train_loader, optimizer, scheduler, epoch):
    model.train()
    total_loss = 0
    num_batches = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1} [Train]', leave=False)
    
    for words, pos, tags in pbar:
        words, pos, tags = words.to(device), pos.to(device), tags.to(device)
        
        optimizer.zero_grad()
        loss = model(words, pos, tags)
        
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Warning: Invalid loss detected: {loss}")
            continue
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['max_grad_norm'])
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
        num_batches += 1
        
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'LR': f'{scheduler.get_last_lr()[0]:.2e}'
        })
    
    return total_loss / num_batches if num_batches > 0 else 0

def evaluate_model(model, val_loader, epoch):
    model.eval()
    all_predictions = []
    all_targets = []
    total_loss = 0
    num_batches = 0
    
    pbar = tqdm(val_loader, desc=f'Epoch {epoch+1} [Val]', leave=False)
    
    with torch.no_grad():
        for words, pos, tags in pbar:
            words, pos, tags = words.to(device), pos.to(device), tags.to(device)
            
            # Get loss
            model.train()
            loss = model(words, pos, tags)
            if not (torch.isnan(loss) or torch.isinf(loss)):
                total_loss += loss.item()
                num_batches += 1
            
            # Get predictions
            model.eval()
            predictions = model(words, pos)
            
            # Process predictions and targets
            mask = (words != PAD_IDX).cpu().numpy()
            tags_cpu = tags.cpu().numpy()
            
            for i in range(len(words)):
                seq_len = int(mask[i].sum())
                if seq_len > 0 and i < len(predictions):
                    pred_seq = predictions[i][:seq_len]
                    true_seq = tags_cpu[i][:seq_len]
                    
                    all_predictions.extend(pred_seq)
                    all_targets.extend(true_seq.tolist())
    
    avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')
    
    # Calculate metrics
    if len(all_predictions) > 0 and len(all_targets) > 0:
        accuracy = np.mean(np.array(all_predictions) == np.array(all_targets))
        f1 = f1_score(all_targets, all_predictions, average='weighted', zero_division=0)
    else:
        accuracy = 0.0
        f1 = 0.0
    
    return avg_loss, accuracy, f1

In [9]:
# Training Loop with Early Stopping
print("Starting training...")
start_time = time.time()

# Training tracking
train_losses = []
val_losses = []
val_accuracies = []
val_f1_scores = []
best_f1 = 0
patience_counter = 0

for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
    
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, epoch)
    
    # Validate
    val_loss, val_accuracy, val_f1 = evaluate_model(model, val_loader, epoch)
    
    # Store metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)
    val_f1_scores.append(val_f1)
    
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}, Val F1: {val_f1:.4f}")
    print(f"Learning Rate: {scheduler.get_last_lr()[0]:.2e}")
    
    # Early stopping based on F1 score
    if val_f1 > best_f1:
        best_f1 = val_f1
        patience_counter = 0
        
        # Save best model
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
            'val_f1': val_f1,
            'config': CONFIG,
            'vocabularies': {
                'word2idx': word2idx,
                'pos2idx': pos2idx,
                'tag2idx': tag2idx,
                'idx2word': idx2word,
                'idx2pos': idx2pos,
                'idx2tag': idx2tag
            }
        }, os.path.join(MODEL_DIR, 'best_ner_model.pt'))
        
        print(f"✓ Best model saved! Val F1: {val_f1:.4f}")
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{CONFIG['patience']}")
        
        if patience_counter >= CONFIG['patience']:
            print("Early stopping triggered!")
            break

total_time = time.time() - start_time
print(f"\nTraining completed in {total_time:.2f} seconds")
print(f"Best validation F1: {best_f1:.4f}")

Starting training...

Epoch 1/15


                                                                                                                        

Train Loss: 9.7835
Val Loss: 3.4974, Val Accuracy: 0.9541, Val F1: 0.9529
Learning Rate: 1.33e-03
✓ Best model saved! Val F1: 0.9529

Epoch 2/15


                                                                                                                        

Train Loss: 2.4995
Val Loss: 1.9386, Val Accuracy: 0.9689, Val F1: 0.9683
Learning Rate: 1.99e-03
✓ Best model saved! Val F1: 0.9683

Epoch 3/15


                                                                                                                        

Train Loss: 1.6199
Val Loss: 1.7045, Val Accuracy: 0.9707, Val F1: 0.9698
Learning Rate: 1.94e-03
✓ Best model saved! Val F1: 0.9698

Epoch 4/15


                                                                                                                        

Train Loss: 1.3517
Val Loss: 1.5695, Val Accuracy: 0.9707, Val F1: 0.9701
Learning Rate: 1.84e-03
✓ Best model saved! Val F1: 0.9701

Epoch 5/15


                                                                                                                        

Train Loss: 1.2241
Val Loss: 1.5146, Val Accuracy: 0.9717, Val F1: 0.9709
Learning Rate: 1.69e-03
✓ Best model saved! Val F1: 0.9709

Epoch 6/15


                                                                                                                        

Train Loss: 1.1228
Val Loss: 1.4875, Val Accuracy: 0.9705, Val F1: 0.9702
Learning Rate: 1.50e-03
No improvement. Patience: 1/4

Epoch 7/15


                                                                                                                        

Train Loss: 1.0482
Val Loss: 1.4578, Val Accuracy: 0.9709, Val F1: 0.9707
Learning Rate: 1.29e-03
No improvement. Patience: 2/4

Epoch 8/15


                                                                                                                        

KeyboardInterrupt: 

In [None]:
# Plot Training History
plt.figure(figsize=(15, 5))

# Loss plot
plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(val_losses, label='Val Loss', color='red')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# Accuracy plot
plt.subplot(1, 3, 2)
plt.plot(val_accuracies, label='Val Accuracy', color='green')
plt.title('Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

# F1 score plot
plt.subplot(1, 3, 3)
plt.plot(val_f1_scores, label='Val F1', color='purple')
plt.title('Validation F1 Score')
plt.xlabel('Epoch')
plt.ylabel('F1 Score')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'training_history.png'), dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Training Summary
print(f"\n" + "="*60)
print("TRAINING SUMMARY")
print("="*60)
print(f"Total epochs completed: {len(train_losses)}")
print(f"Final train loss: {train_losses[-1]:.4f}")
print(f"Final val loss: {val_losses[-1]:.4f}")
print(f"Final val accuracy: {val_accuracies[-1]:.4f}")
print(f"Final val F1: {val_f1_scores[-1]:.4f}")
print(f"Best val F1: {best_f1:.4f}")
print(f"Training time: {total_time:.2f} seconds")
print(f"Model parameters: {total_params:,}")
print(f"Model saved to: {os.path.join(MODEL_DIR, 'best_ner_model.pt')}")