In [13]:
# ===============================
# 1️⃣ Imports
# ===============================
import pickle
import time

import torch
import torch.nn as nn
import torch.optim as optim
from flair.data import Sentence
from flair.embeddings import WordEmbeddings, FlairEmbeddings, StackedEmbeddings
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# ===============================
# 2️⃣ Load CoNLL Data
# ===============================
def load_conll(path):
    """Read a CoNLL-style file into parallel token and tag sequences.

    Every non-blank line is split on whitespace; the first column is taken as
    the token and the last column as its tag. A blank line ends the current
    sentence. Lines with fewer than two columns are skipped.
    ``-DOCSTART-`` document markers are metadata, not tokens: they are dropped
    and treated as a sentence boundary instead of being returned as a
    one-token sentence.

    Args:
        path: Path of the UTF-8 encoded file to read.

    Returns:
        A ``(sentences, labels)`` tuple of equally long lists, where
        ``sentences[i]`` is the token list of sentence ``i`` and ``labels[i]``
        is the tag list aligned with it.
    """
    sentences, labels = [], []
    sent, lab = [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if parts and parts[0] == "-DOCSTART-":
                # Document separator: handle it like a blank line.
                parts = []
            if not parts:  # sentence boundary
                if sent:
                    sentences.append(sent)
                    labels.append(lab)
                    sent, lab = [], []
                continue
            if len(parts) >= 2:
                sent.append(parts[0])
                lab.append(parts[-1])
    if sent:  # file did not end with a blank line
        sentences.append(sent)
        labels.append(lab)
    return sentences, labels

# Read the train / validation / test splits, in that order, from the working directory.
(train_texts, train_labels), (val_texts, val_labels), (test_texts, test_labels) = (
    load_conll(split_path)
    for split_path in ("ner_train.conll", "ner_val.conll", "ner_test.conll")
)



In [14]:
# ===============================
# 3️⃣ Build tag vocab
# ===============================
# Collect the tag sequences of all three splits so every tag gets an index.
all_labels = [*train_labels, *val_labels, *test_labels]
unique_tags = sorted({tag for seq in all_labels for tag in seq})

# Real tags are numbered from 1 in sorted order; index 0 is reserved for PAD.
ner_tag_to_ix = dict(zip(unique_tags, range(1, len(unique_tags) + 1)))
ner_tag_to_ix.update(PAD=0)

# Reverse lookup (index -> tag) and the number of classes including PAD.
id2tag = {ix: tag for tag, ix in ner_tag_to_ix.items()}
tagset_size = len(ner_tag_to_ix)

print("NER tag vocab:", ner_tag_to_ix)

# ===============================
# 4️⃣ Flair Embeddings
# ===============================
# One GloVe word embedding stacked with the forward and backward Flair
# "news" embeddings; the stacked width becomes the model's input size.
stacked_embeddings = StackedEmbeddings(
    [WordEmbeddings("glove")]
    + [FlairEmbeddings(lm_name) for lm_name in ("news-forward", "news-backward")]
)
embedding_dim = stacked_embeddings.embedding_length

NER tag vocab: {'B-DATE': 1, 'B-LOC': 2, 'B-TIME': 3, 'I-DATE': 4, 'I-LOC': 5, 'I-TIME': 6, 'O': 7, 'PAD': 0}


In [15]:
# ===============================
# 5️⃣ Dataset
# ===============================
class NERDataset(Dataset):
    """Dataset yielding fixed-length (embeddings, tag ids) pairs per sentence.

    Each item is embedded on the fly with the supplied embedding object and
    then padded or truncated to ``max_len`` positions.

    Args:
        texts: Sequence of sentences, each a list of token strings.
        labels: Sequence of tag lists aligned with ``texts``.
        embeddings: Object exposing ``embed(sentence)`` and
            ``embedding_length`` (here a Flair embedding stack).
        tag_to_ix: Mapping from tag string to integer id; must contain the
            key ``"PAD"``.
        max_len: Number of positions every returned item has.
    """

    def __init__(self, texts, labels, embeddings, tag_to_ix, max_len=128):
        self.texts = texts
        self.labels = labels
        self.embeddings = embeddings
        self.tag_to_ix = tag_to_ix
        self.max_len = max_len

    def __len__(self):
        """Return the number of sentences."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Embed sentence ``idx`` and return padded tensors.

        Returns:
            A tuple ``(emb, tag_ids)`` where ``emb`` is a float32 tensor of
            shape ``(max_len, embedding_length)`` and ``tag_ids`` is a long
            tensor of shape ``(max_len,)``. Positions past the end of the
            sentence hold zeros and ``tag_to_ix["PAD"]`` respectively.
        """
        tokens = self.texts[idx]
        tags = self.labels[idx]

        # Flair Sentence (use given tokens, no re-tokenization)
        sentence = Sentence(" ".join(tokens), use_tokenizer=False)
        # The vectors are used as fixed features (they are detached below),
        # so there is no need to track gradients while embedding.
        with torch.no_grad():
            self.embeddings.embed(sentence)

        # Pair tokens with gold tags and truncate to the fixed length.
        pairs = list(zip(sentence, tags))[: self.max_len]
        n_real = len(pairs)

        # Pre-allocate the padded outputs and copy the real positions in.
        # Building the padding as nested Python lists and converting a mixed
        # list of arrays with torch.tensor() is very slow for wide embeddings.
        emb = torch.zeros(
            self.max_len, self.embeddings.embedding_length, dtype=torch.float32
        )
        tag_ids = torch.full(
            (self.max_len,), self.tag_to_ix["PAD"], dtype=torch.long
        )
        if n_real:
            emb[:n_real] = torch.stack(
                [tok.embedding.detach() for tok, _ in pairs]
            ).to(device="cpu", dtype=torch.float32)
            tag_ids[:n_real] = torch.tensor(
                [self.tag_to_ix[gold_tag] for _, gold_tag in pairs],
                dtype=torch.long,
            )

        return emb, tag_ids


In [16]:
# ===============================
# 6️⃣ DataLoader
# ===============================
batch_size = 16

# One dataset per split, all sharing the same embedding stack and tag vocabulary.
train_dataset, val_dataset, test_dataset = (
    NERDataset(split_texts, split_labels, stacked_embeddings, ner_tag_to_ix)
    for split_texts, split_labels in (
        (train_texts, train_labels),
        (val_texts, val_labels),
        (test_texts, test_labels),
    )
)

# Only the training loader shuffles; validation and test keep file order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=batch_size)
    for split_dataset in (val_dataset, test_dataset)
)

# ===============================
# 7️⃣ BiLSTM + CRF Model
# ===============================
class BiLSTM_CRF_Flair(nn.Module):
    """Bidirectional LSTM token tagger over precomputed embeddings.

    NOTE: despite the name, no CRF is applied. The training loss is
    token-level cross-entropy and decoding is a per-token argmax. The
    ``transitions`` parameter is allocated (and stored in the state dict)
    but is not used by ``forward``.

    Args:
        embedding_dim: Size of each input token vector.
        hidden_dim: Total BiLSTM output size (split evenly over the two
            directions); must be even.
        tagset_size: Number of output classes, including the PAD class.
        pad_idx: Index of the PAD class. Defaults to 0, the value assigned to
            ``"PAD"`` in this notebook's tag vocabulary.

    Raises:
        ValueError: If ``hidden_dim`` is odd, since the concatenated BiLSTM
            output would not match the input size of the linear layer.
    """

    def __init__(self, embedding_dim, hidden_dim, tagset_size, pad_idx=0):
        super().__init__()
        if hidden_dim % 2 != 0:
            raise ValueError(
                f"hidden_dim must be even for a bidirectional LSTM, got {hidden_dim}"
            )
        self.hidden_dim = hidden_dim
        self.pad_idx = pad_idx
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

        # CRF parameters (currently unused, see class docstring).
        self.transitions = nn.Parameter(torch.randn(tagset_size, tagset_size))
        with torch.no_grad():
            # Make every transition into or out of PAD effectively impossible.
            self.transitions[:, pad_idx] = -10000
            self.transitions[pad_idx, :] = -10000

    def forward(self, x, tags=None, mask=None):
        """Return the loss when ``tags`` is given, otherwise predicted tag ids.

        Args:
            x: Batch-first input embeddings (the LSTM uses ``batch_first=True``).
            tags: Optional gold tag ids aligned with ``x``; PAD positions are
                ignored by the loss.
            mask: Optional boolean tensor, True at real (non-PAD) positions.
                Only used when decoding.
        """
        lstm_out, _ = self.lstm(x)
        emissions = self.hidden2tag(lstm_out)

        if tags is not None:  # training
            return self.neg_log_likelihood(emissions, tags, mask)
        return self.decode(emissions, mask)  # inference

    def neg_log_likelihood(self, emissions, tags, mask):
        """Token-level cross-entropy over all non-PAD positions.

        Simplified stand-in for a CRF loss; ``mask`` is accepted for interface
        compatibility, and PAD positions are excluded through ``ignore_index``.
        """
        # reshape (not view) so non-contiguous inputs are also accepted.
        flat_emissions = emissions.reshape(-1, emissions.shape[-1])
        flat_tags = tags.reshape(-1)
        return nn.functional.cross_entropy(
            flat_emissions, flat_tags, ignore_index=self.pad_idx
        )

    def decode(self, emissions, mask):
        """Per-token argmax over the real tags.

        PAD is never a valid label for a real token, so it is excluded from
        the argmax. When ``mask`` is given, positions where it is False are
        set to ``pad_idx``.
        """
        scores = emissions.clone()
        scores[..., self.pad_idx] = float("-inf")
        best = torch.argmax(scores, dim=-1)
        if mask is not None:
            best = best.masked_fill(~mask.bool(), self.pad_idx)
        return best

# ===============================
# 8️⃣ Training
# ===============================
# Fix the RNG so weight initialisation and batch shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BiLSTM_CRF_Flair(embedding_dim, 128, tagset_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

n_epochs = 3
start_time = time.time()
for epoch in range(1, n_epochs+1):
    model.train()
    total_loss = 0.0
    n_batches = 0
    for X, y in tqdm(train_loader, desc=f"Epoch {epoch}/{n_epochs}"):
        X, y = X.to(device), y.to(device)
        # True at real tokens, False at padded positions.
        mask = (y != ner_tag_to_ix["PAD"])
        optimizer.zero_grad()
        loss = model(X, tags=y, mask=mask)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    # The summed loss grows with the number of batches; the per-batch average
    # is the figure that stays comparable across dataset sizes.
    avg_loss = total_loss / max(n_batches, 1)
    print(f"Epoch {epoch}/{n_epochs}, Loss: {total_loss:.4f} (avg/batch: {avg_loss:.4f})")
print(f"Training completed in {time.time() - start_time:.2f} seconds")

Epoch 1/3:  11%|█         | 302/2869 [2:19:12<19:43:18, 27.66s/it]


KeyboardInterrupt: 

In [None]:
# ===============================
# 9️⃣ Evaluation
# ===============================
# Token-level accuracy on the test split, counted over non-PAD positions only.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for X, y in tqdm(test_loader, desc="Evaluating"):
        X, y = X.to(device), y.to(device)
        mask = (y != ner_tag_to_ix["PAD"])
        preds = model(X, mask=mask)
        correct += ((preds == y) & mask).sum().item()
        total += mask.sum().item()

if total:
    print(f"Test Accuracy: {correct/total:.4f}")
else:
    # Empty test split (or only padding): avoid a ZeroDivisionError.
    print("Test Accuracy: n/a (no non-PAD tokens were evaluated)")

# ===============================
# 🔟 Save model + vocab
# ===============================
# Persist the learned weights and the tag vocabulary needed to interpret the
# model's output indices. (`pickle` is imported in the imports cell.)
torch.save(model.state_dict(), "bilstm_crf_flair.pth")
# NOTE: only unpickle this file from a trusted source; pickle.load can run
# arbitrary code.
with open("ner_tag_to_ix.pkl", "wb") as f:
    pickle.dump(ner_tag_to_ix, f)
