# Hybrid Model Batch Fine-Tuning with Label Expansion
This notebook implements a hybrid NER system that integrates transformer embeddings (SciBERT),
external knowledge features from SciSpaCy, a BiLSTM layer (optional), and a CRF for sequence decoding.
It supports batch-wise fine-tuning and dynamically updates the classifier and CRF layers when new labels appear.
The code retains learned weights for base layers and handles device mismatches and optimizer state loading.
To address CUDA out-of-memory issues, checkpoints are loaded on CPU and the batch size has been reduced.

In [None]:
# In[1] - Environment Setup

# Uncomment and run these if the packages are not yet installed:
# !pip install torch torchvision torchaudio
# !pip install transformers
# !pip install scispacy
# !pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/en_ner_bc5cdr_md-0.5.0.tar.gz
# !pip install torchcrf
# !pip install seqeval

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchcrf import CRF

from transformers import AutoModel, AutoTokenizer
import spacy
from seqeval.metrics import f1_score, precision_score, recall_score, classification_report

import json
import numpy as np
import random
import os
import glob
import re
from torch.utils.data.dataloader import default_collate

In [None]:
# In[2] - GPU Check

# Use CUDA when PyTorch can see a GPU; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
# In[3] - Load SciSpaCy Model for Knowledge Features

# Knowledge features are optional: if the model cannot be loaded, `nlp` stays
# None and get_knowledge_features() returns all-zero features.
nlp = None
try:
    nlp = spacy.load("en_ner_bc5cdr_md")
except Exception as e:
    print("Could not load SciSpaCy model. Error:", e)
else:
    print("Loaded SciSpaCy model: en_ner_bc5cdr_md")

In [None]:
# In[4] - Define a Simple Knowledge Feature Extraction Function

def get_knowledge_features(tokens):
    """
    For each token in the sentence, return a binary feature (0 or 1)
    indicating whether it overlaps an entity recognized by SciSpaCy.

    The tokens are joined with single spaces and run through ``nlp``. SciSpaCy
    re-tokenizes that text with its own rules and may split one input token into
    several (e.g. around hyphens or punctuation), so entity *token* indices do not
    line up with ``tokens``. Instead, each entity's character span is compared
    with the character span every input token occupies in the joined string.

    Returns:
        list[int]: one 0/1 flag per input token. All zeros when ``nlp`` is None
        (model not loaded) or ``tokens`` is empty.
    """
    if nlp is None or not tokens:
        return [0] * len(tokens)

    # Character [start, end) of each input token inside " ".join(tokens).
    token_spans = []
    cursor = 0
    for token in tokens:
        token_spans.append((cursor, cursor + len(token)))
        cursor += len(token) + 1  # +1 for the separating space

    doc = nlp(" ".join(tokens))
    feats = [0] * len(tokens)
    for ent in doc.ents:
        for i, (start, end) in enumerate(token_spans):
            # Half-open interval overlap between token and entity spans.
            if start < ent.end_char and ent.start_char < end:
                feats[i] = 1
    return feats

In [None]:
# In[5] - Load the SciBERT Tokenizer

# The same checkpoint name is the default encoder of HybridNERModel.
model_name = "allenai/scibert_scivocab_uncased"
# use_fast=True is AutoTokenizer's default; it is spelled out because NERDataset
# relies on fast-tokenizer-only features (offset mapping / word alignment).
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

In [None]:
# In[6] - Define the NER Dataset Class

class NERDataset(Dataset):
    """Word-level NER examples from a JSON-lines file, aligned to subword tokens.

    Each non-blank line of ``file_path`` must be a JSON object with parallel
    ``"tokens"`` and ``"tags"`` lists. ``__getitem__`` tokenizes one example,
    pads/truncates it to ``max_length`` subwords and returns a dict with
    ``input_ids``, ``attention_mask``, ``ner_labels`` (long) and
    ``knowledge_feats`` (float, one SciSpaCy flag per subword).

    Alignment uses the fast tokenizer's ``word_ids()``:
      * every subword of a word gets that word's tag id (tags missing from
        ``label2id`` fall back to "O") and the word's knowledge feature;
      * special tokens and padding get "O" and feature 0.

    ``label2id`` must contain the key "O".
    """

    def __init__(self, file_path, tokenizer, label2id, max_length=128):
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_length = max_length
        self.samples = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                # Skip blank lines (e.g. a trailing empty line); json.loads("") raises.
                if line.strip():
                    self.samples.append(json.loads(line))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]
        tokens = item["tokens"]
        tags = item["tags"]
        knowledge_feats = get_knowledge_features(tokens)
        outside_id = self.label2id["O"]

        encoding = self.tokenizer(
            tokens,
            is_split_into_words=True,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # word_ids() maps each subword position to the index of its source word
        # (None for special tokens and padding). Unlike an offset-based heuristic
        # it stays aligned even when a word yields no subwords, and it keeps
        # [SEP]/padding from inheriting the last word's tag.
        ner_ids = []
        knowledge_ids = []
        for word_idx in encoding.word_ids(batch_index=0):
            if word_idx is None or word_idx >= len(tags):
                ner_ids.append(outside_id)
                knowledge_ids.append(0)
            else:
                ner_ids.append(self.label2id.get(tags[word_idx], outside_id))
                knowledge_ids.append(
                    knowledge_feats[word_idx] if word_idx < len(knowledge_feats) else 0
                )

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'ner_labels': torch.tensor(ner_ids, dtype=torch.long),
            'knowledge_feats': torch.tensor(knowledge_ids, dtype=torch.float),
        }

In [None]:
# In[8] - Define the NER Model

class HybridNERModel(nn.Module):
    """SciBERT + knowledge-feature tagger with an optional BiLSTM and a CRF output layer.

    Per token, the transformer's last hidden state is concatenated with
    ``knowledge_feature_dim`` extra features, optionally passed through a
    bidirectional LSTM, projected to ``num_ner_labels`` emission scores, and
    scored / decoded by a CRF.

    NOTE: submodules are registered in the order transformer -> bilstm ->
    ner_classifier -> crf. Saved checkpoints filter and load parameters by these
    attribute names, and saved optimizer states follow this parameter order, so
    keep both stable.
    """

    def __init__(self,
                 transformer_name=model_name,
                 hidden_dim=128,
                 num_ner_labels=10,
                 knowledge_feature_dim=1,
                 use_bilstm=True):
        super().__init__()

        self.transformer = AutoModel.from_pretrained(transformer_name)
        # Width of each token vector after appending the knowledge features.
        self.feature_dim = self.transformer.config.hidden_size + knowledge_feature_dim

        self.use_bilstm = use_bilstm
        self.hidden_dim = hidden_dim
        encoder_out_dim = self.feature_dim
        if use_bilstm:
            self.bilstm = nn.LSTM(
                input_size=self.feature_dim,
                hidden_size=hidden_dim,
                batch_first=True,
                bidirectional=True
            )
            encoder_out_dim = 2 * hidden_dim  # forward and backward states concatenated

        self.num_ner_labels = num_ner_labels
        self.ner_classifier = nn.Linear(encoder_out_dim, num_ner_labels)
        # Built without batch_first; forward()/decode() therefore transpose the
        # batch-major tensors to time-major before calling the CRF.
        self.crf = CRF(num_ner_labels)

    def forward(self, input_ids, attention_mask, knowledge_features, ner_labels=None):
        """Return ``(emissions, loss)``; ``loss`` is the negated CRF score, or None without labels."""
        hidden = self.transformer(input_ids, attention_mask=attention_mask).last_hidden_state

        # Bring knowledge features to a rank that concatenates with `hidden`.
        rank = knowledge_features.dim()
        if rank == 2:
            knowledge_features = knowledge_features.unsqueeze(-1)
        elif rank == 4:
            knowledge_features = knowledge_features.squeeze(-1)
        features = torch.cat([hidden, knowledge_features], dim=-1)

        sequence_out = self.bilstm(features)[0] if self.use_bilstm else features
        emissions = self.ner_classifier(sequence_out)

        if ner_labels is None:
            return emissions, None

        # Normalise label rank to (batch, seq) before transposing.
        if ner_labels.dim() == 3:
            ner_labels = ner_labels.squeeze(1)
        elif ner_labels.dim() == 1:
            ner_labels = ner_labels.unsqueeze(1).expand(-1, emissions.shape[1])

        log_likelihood = self.crf(
            emissions.transpose(0, 1),
            ner_labels.transpose(0, 1),
            mask=attention_mask.bool().transpose(0, 1),
        )
        return emissions, -log_likelihood

    def decode(self, emissions, attention_mask):
        """Viterbi-decode the best tag sequence for every item in the batch."""
        return self.crf.decode(
            emissions.transpose(0, 1),
            mask=attention_mask.bool().transpose(0, 1),
        )

In [None]:
# In[9] - Define Training and Evaluation Functions

from tqdm import tqdm

def train_one_epoch(model, dataloader, optimizer):
    """Run one optimisation pass over ``dataloader``.

    For each batch: forward with gold labels, back-propagate the (mean-reduced)
    CRF loss, step the optimizer, then Viterbi-decode the same emissions to
    track token accuracy over attention-masked positions.

    Returns:
        (avg_loss, accuracy): mean loss over the batches that produced a loss,
        and the fraction of unmasked tokens whose decoded tag equals the gold
        tag. Both are 0.0 when no batch was processed (e.g. an empty
        dataloader) instead of raising ZeroDivisionError.
    """
    model.train()
    total_loss = 0.0
    num_loss_batches = 0
    total_correct = 0
    total_tokens = 0
    for batch in tqdm(dataloader, desc="Training", leave=False):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        ner_labels = batch['ner_labels'].to(device)
        knowledge_feats = batch['knowledge_feats'].to(device)

        optimizer.zero_grad()
        emissions, ner_loss = model(input_ids, attention_mask, knowledge_feats, ner_labels=ner_labels)
        if ner_loss is None:
            continue
        ner_loss = ner_loss.mean()
        ner_loss.backward()
        optimizer.step()

        total_loss += ner_loss.item()
        num_loss_batches += 1  # skipped batches must not dilute the average

        with torch.no_grad():
            pred_sequences = model.decode(emissions, attention_mask)
            for preds, golds, mask in zip(pred_sequences, ner_labels, attention_mask):
                valid_len = int(mask.sum().item())
                preds_tensor = torch.tensor(preds[:valid_len], device=golds.device)
                total_correct += (preds_tensor == golds[:valid_len]).sum().item()
                total_tokens += valid_len

    avg_loss = total_loss / num_loss_batches if num_loss_batches else 0.0
    accuracy = total_correct / total_tokens if total_tokens > 0 else 0.0
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return avg_loss, accuracy

def evaluate(model, dataloader, id2label):
    """Decode ``dataloader`` with the CRF and report seqeval entity-level metrics.

    Args:
        model: tagger exposing ``model(input_ids, attention_mask, knowledge_feats,
            ner_labels=None) -> (emissions, _)`` and ``model.decode(emissions, mask)``.
        dataloader: yields dicts with ``input_ids``, ``attention_mask``,
            ``ner_labels`` and ``knowledge_feats`` tensors.
        id2label: maps tag id -> tag string. Gold ids equal to -100 are scored as "O".

    Returns:
        dict with "precision", "recall" and "f1" (all 0.0 when the loader yields
        no examples). The report and scores are also printed.
    """
    model.eval()
    all_preds = []
    all_true = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            knowledge_feats = batch['knowledge_feats'].to(device)
            # Gold ids are only converted to strings on the host; keep them on the CPU.
            ner_labels = batch['ner_labels'].cpu()

            emissions, _ = model(input_ids, attention_mask, knowledge_feats, ner_labels=None)
            pred_sequences = model.decode(emissions, attention_mask)

            for preds, golds, mask in zip(pred_sequences, ner_labels, attention_mask):
                valid_len = int(mask.sum().item())
                gold_ids = golds[:valid_len].tolist()
                all_preds.append([id2label[p] for p in preds[:valid_len]])
                all_true.append([id2label[g] if g != -100 else "O" for g in gold_ids])

    if not all_true:
        print("No evaluation examples; skipping SeqEval report.")
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}

    print("SeqEval Classification Report:")
    print(classification_report(all_true, all_preds))
    p = precision_score(all_true, all_preds)
    r = recall_score(all_true, all_preds)
    f1 = f1_score(all_true, all_preds)
    print(f"Precision: {p:.4f}, Recall: {r:.4f}, F1: {f1:.4f}")
    return {"precision": p, "recall": r, "f1": f1}

In [None]:
def load_model_for_finetuning(model, optimizer, checkpoint_path):
    """Restore model and optimizer state from ``checkpoint_path``; return the next epoch index.

    The checkpoint dict must hold ``model_state_dict``, ``optimizer_state_dict``
    and ``epoch``. It is mapped to the CPU first so that loading does not
    allocate a second copy of the weights on the GPU.
    """
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    next_epoch = state['epoch'] + 1
    print(f"Resuming training from epoch {next_epoch}")
    return next_epoch


In [None]:
def custom_collate_fn(batch):
    """Stack a list of per-sample dicts into a dict of batched tensors.

    ``ner_labels`` may arrive as tensors or plain lists. ``torch.as_tensor``
    converts lists and reuses tensors as-is, avoiding the copy-construct
    UserWarning that ``torch.tensor(existing_tensor)`` emits.

    Only a genuine singleton axis (samples shaped ``(1, seq)`` stacking to
    ``(batch, 1, seq)``) is squeezed away. This keeps ``(batch, 1)`` labels for
    length-1 sequences intact, which an unconditional ``squeeze(1)`` would
    collapse to ``(batch,)``.
    """
    ner_labels = torch.stack([torch.as_tensor(item["ner_labels"]) for item in batch])
    if ner_labels.dim() == 3 and ner_labels.size(1) == 1:
        ner_labels = ner_labels.squeeze(1)

    batch_dict = {
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "ner_labels": ner_labels,
        "knowledge_feats": torch.stack([item["knowledge_feats"] for item in batch])
    }
    return batch_dict

In [None]:
def load_previous_label2id(checkpoint_path):
    """Return the label->id map stored in ``<checkpoint_path>/label2id.json``, or None if absent."""
    candidate = os.path.join(checkpoint_path, "label2id.json")
    if not os.path.exists(candidate):
        return None
    with open(candidate, "r") as handle:
        return json.load(handle)

In [None]:
def update_model_classifier(model, old_label2id, new_label2id):
    """Resize ``model``'s emission head and CRF to ``new_label2id``, keeping learned weights.

    Weights are transferred by label *name*: for every label present in both maps,
    the classifier row and bias, the CRF start/end scores and the pairwise CRF
    transition scores are copied from index ``old_label2id[label]`` to index
    ``new_label2id[label]``. Labels that are new in ``new_label2id`` keep PyTorch's
    fresh initialization. The new layers are created on the device (and, for the
    classifier, the dtype) of the existing ones.

    Args:
        model: module exposing ``ner_classifier`` (``nn.Linear``) and ``crf``. The
            CRF is rebuilt as ``type(model.crf)(num_tags, batch_first=...)`` and must
            expose ``start_transitions``, ``end_transitions`` and ``transitions``.
        old_label2id: label -> index map the current head was trained with, or
            None / {} when nothing should be carried over.
        new_label2id: label -> index map for the resized head.

    Returns:
        The same ``model`` with ``ner_classifier`` and ``crf`` replaced.
    """
    old_label2id = old_label2id or {}
    new_num_labels = len(new_label2id)
    old_classifier = model.ner_classifier
    old_crf = model.crf
    head_device = old_classifier.weight.device
    crf_device = old_crf.transitions.device

    new_classifier = nn.Linear(
        old_classifier.in_features,
        new_num_labels,
        bias=old_classifier.bias is not None,
    ).to(device=head_device, dtype=old_classifier.weight.dtype)
    # Rebuild with the same CRF class and orientation instead of assuming defaults.
    new_crf = type(old_crf)(
        new_num_labels, batch_first=getattr(old_crf, "batch_first", False)
    ).to(crf_device)

    # Only indices that exist in both the old classifier and the old CRF are copied.
    old_capacity = min(old_classifier.out_features, old_crf.start_transitions.shape[0])
    shared = [
        (new_idx, old_label2id[label])
        for label, new_idx in new_label2id.items()
        if label in old_label2id
        and 0 <= new_idx < new_num_labels
        and 0 <= old_label2id[label] < old_capacity
    ]

    if shared:
        new_rows = torch.tensor([n for n, _ in shared], device=head_device)
        old_rows = torch.tensor([o for _, o in shared], device=head_device)
        new_tags = new_rows.to(crf_device)
        old_tags = old_rows.to(crf_device)
        with torch.no_grad():
            new_classifier.weight[new_rows] = old_classifier.weight[old_rows]
            if new_classifier.bias is not None:
                new_classifier.bias[new_rows] = old_classifier.bias[old_rows]
            new_crf.start_transitions[new_tags] = old_crf.start_transitions[old_tags]
            new_crf.end_transitions[new_tags] = old_crf.end_transitions[old_tags]
            # (k, k) grid of transition scores between every pair of shared labels.
            new_crf.transitions[new_tags.unsqueeze(1), new_tags] = (
                old_crf.transitions[old_tags.unsqueeze(1), old_tags]
            )

    model.ner_classifier = new_classifier
    model.crf = new_crf
    return model

In [None]:
def natural_sort_key(file_name):
    """Sort key for numbered file names: every run of digits in ``file_name``, as ints.

    ``"combined_train_10.jsonl"`` -> ``[10]``, so batch 10 sorts after batch 2.
    Non-digit characters do not take part in the key.
    """
    return list(map(int, re.findall(r'\d+', file_name)))

In [None]:
# Define directory containing batch files. Both locations can be overridden with
# environment variables so the notebook is not tied to one machine's mount points.
batch_data_dir = os.environ.get(
    "NER_BATCH_DATA_DIR",
    "/media/smartdragon/WORK/6th Semester/22AIE315 - Natural Language Processing/Project/New_Json_Files",
)


def _batch_files(prefix):
    """Files ``<prefix>_*.jsonl`` in batch_data_dir, ordered by the numbers in their file names."""
    paths = glob.glob(os.path.join(batch_data_dir, f"{prefix}_*.jsonl"))
    # Sort on the file name only, so digits in the directory path cannot affect the order.
    return sorted(paths, key=lambda p: natural_sort_key(os.path.basename(p)))


train_files = _batch_files("combined_train")
dev_files = _batch_files("combined_dev")
test_files = _batch_files("combined_test")

# Fail loudly: otherwise the training loop silently does nothing and still reports success.
if not train_files:
    raise FileNotFoundError(f"No combined_train_*.jsonl files found in {batch_data_dir!r}")

print(f"Found {len(train_files)} training batches, {len(dev_files)} dev batches, {len(test_files)} test batches")

base_checkpoint_path = os.environ.get(
    "NER_CHECKPOINT_DIR",
    "/media/smartdragon/Windows-SSD/Users/sriva/Documents/NLP/HybridModel",
)
os.makedirs(base_checkpoint_path, exist_ok=True)
# Per-batch checkpoint directories (batch_<N>) are derived from this root in the training loop.
checkpoint_path = base_checkpoint_path

# %% [code]
# Iterate over each batch and fine-tune sequentially.
#
# Label-map policy: every batch's label2id *extends* the previous batch's map.
# Known labels keep their index and unseen labels are appended, so classifier
# rows and CRF transitions keep their meaning from one batch to the next.
MAX_SEQ_LENGTH = 128     # subword budget per sentence (4 kept only [CLS] + 2 subwords + [SEP])
TRAIN_BATCH_SIZE = 16    # lower this first if CUDA runs out of memory
LEARNING_RATE = 2e-5
NUM_FINETUNE_EPOCHS = 1


def _extend_label_map(previous_label2id, tags):
    """Return ``previous_label2id`` with unseen ``tags`` (sorted) appended; always includes "O"."""
    label_map = dict(previous_label2id or {})
    next_id = max(label_map.values(), default=-1) + 1
    for tag in sorted(set(tags) | {"O"}):
        if tag not in label_map:
            label_map[tag] = next_id
            next_id += 1
    return label_map


def _build_model(num_labels):
    """Fresh HybridNERModel with this notebook's fixed architecture settings."""
    return HybridNERModel(
        transformer_name=model_name,
        hidden_dim=128,
        num_ner_labels=num_labels,
        knowledge_feature_dim=1,
        use_bilstm=True,
    )


def _restore_model(checkpoint, previous_label2id, label2id):
    """Build a model for ``label2id`` seeded from ``checkpoint``.

    Returns ``(model, optimizer_state_reusable)``. The saved optimizer state is
    only reusable when the full model was restored with an unchanged label map;
    otherwise the head's parameter shapes differ from the saved ones.
    """
    state = checkpoint["model_state_dict"]
    stored_head = state.get("ner_classifier.weight")
    if previous_label2id and stored_head is not None and stored_head.shape[0] == len(previous_label2id):
        # The stored head matches the previous label map: restore everything,
        # then grow the classifier/CRF if new labels appeared.
        model = _build_model(len(previous_label2id))
        model.load_state_dict(state)
        if previous_label2id == label2id:
            print("✅ Loaded full model checkpoint (label mapping unchanged)")
            return model, True
        model = update_model_classifier(model, previous_label2id, label2id)
        print("✅ Loaded model checkpoint and extended classifier/CRF to the new labels")
        return model, False
    # No trustworthy label map for the stored head: keep only label-independent layers.
    model = _build_model(len(label2id))
    base_state = {k: v for k, v in state.items()
                  if not k.startswith(("ner_classifier", "crf"))}
    model.load_state_dict(base_state, strict=False)
    print("✅ Loaded model checkpoint (filtered classifier and CRF parameters)")
    return model, False


for train_file in train_files:
    # Drop the previous batch's model/optimizer before building new ones, so two
    # full models (plus AdamW state) are never resident on the GPU at once.
    model = optimizer = train_loader = train_dataset = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # File names look like combined_train_<N>.jsonl; <N> selects the checkpoint folder.
    batch_number = os.path.splitext(os.path.basename(train_file))[0].split("_")[-1]

    old_checkpoint_path = os.path.join(base_checkpoint_path, f"batch_{int(batch_number)-1}")
    checkpoint_path = os.path.join(base_checkpoint_path, f"batch_{batch_number}")
    print(f"🚀 Processing batch {batch_number}")

    os.makedirs(checkpoint_path, exist_ok=True)

    previous_label2id = load_previous_label2id(old_checkpoint_path)

    with open(train_file, "r", encoding="utf-8") as f:
        unique_tags = {
            tag
            for line in f
            if line.strip()
            for tag in json.loads(line)["tags"]
        }
    label2id = _extend_label_map(previous_label2id, unique_tags)

    label2id_path = os.path.join(checkpoint_path, "label2id.json")
    with open(label2id_path, "w") as f:
        json.dump(label2id, f)
    print(f"✅ Saved label2id for batch {batch_number}")

    model_checkpoint = os.path.join(old_checkpoint_path, "model.pt")
    print(f"Looking for model checkpoint: {model_checkpoint}")

    checkpoint = None
    reuse_optimizer_state = False
    if os.path.exists(model_checkpoint):
        # Map to CPU first so loading does not need a second GPU copy of the weights.
        checkpoint = torch.load(model_checkpoint, map_location="cpu")
        model, reuse_optimizer_state = _restore_model(checkpoint, previous_label2id, label2id)
    else:
        model = _build_model(len(label2id))
    model = model.to(device)

    train_dataset = NERDataset(train_file, tokenizer, label2id, max_length=MAX_SEQ_LENGTH)
    train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True,
                              collate_fn=custom_collate_fn)

    # Created after model.to(device) so it tracks the on-device parameters.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    if reuse_optimizer_state:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print(f"✅ Loaded optimizer state for batch {batch_number}")
    elif checkpoint is not None:
        print("⚠️ Skipping optimizer state load (classifier/CRF were rebuilt for this batch)")
    else:
        print("⚠️ No previous checkpoint found; starting from the pretrained encoder")
    checkpoint = None  # release the CPU copy of the previous weights

    print(f"✅ Successfully set up processing for batch {batch_number}")

    train_loss = float("nan")
    for epoch in range(NUM_FINETUNE_EPOCHS):
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
        print(f"Batch {batch_number}, Epoch {epoch+1}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

    new_checkpoint_path = os.path.join(checkpoint_path, "model.pt")
    torch.save({
        "epoch": NUM_FINETUNE_EPOCHS,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": train_loss,
    }, new_checkpoint_path)
    print(f"✅ Saved model checkpoint for batch {batch_number}")

print("🎉 All batches processed successfully!")


Using device: cuda


  deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(  # type: ignore[union-attr]


Loaded SciSpaCy model: en_ner_bc5cdr_md
✅ Correctly sorted 19 training batches, 19 dev batches, 19 test batches
Found 19 training batches, 19 dev batches, 19 test batches
🚀 Processing batch 2
✅ Saved label2id for batch 2
Looking for model checkpoint: /media/smartdragon/Windows-SSD/Users/sriva/Documents/NLP/HybridModel/batch_1/model.pt
✅ Loaded model checkpoint for batch 2 (filtered classifier and CRF parameters)
✅ Updated model classifier for batch 2
⚠️ Skipping optimizer state load due to classifier update (label mapping change)
✅ Successfully set up processing for batch 2


  "ner_labels": torch.stack([torch.tensor(item["ner_labels"]) for item in batch]).squeeze(1),
Training:  11%|█         | 1051/9614 [01:45<15:07,  9.44it/s]