# IMPORTS & CONFIG

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset
import pickle
import numpy as np
from sklearn.metrics import classification_report

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device =", DEVICE)


Device = cuda


# LOAD EMBEDDINGS FROM PART 1

In [2]:
emb = np.load("/kaggle/input/custom-word2vec-output/embeddings.npy")  
with open("/kaggle/input/custom-word2vec-output/word_to_idx.pkl", "rb") as f:
    word_to_idx = pickle.load(f)

orig_vocab_size, EMBED_DIM = emb.shape
print("Original vocab:", orig_vocab_size, "Embedding dim:", EMBED_DIM)


PAD_ID = 0
UNK_ID = 1

new_emb = np.zeros((orig_vocab_size + 2, EMBED_DIM), dtype=np.float32)
new_emb[2:] = emb                              
new_emb[UNK_ID] = np.random.uniform(-0.01,0.01,EMBED_DIM)  

# shift word_to_idx by +2
new_w2i = {"<PAD>":0, "<UNK>":1}
for w, i in word_to_idx.items():
    new_w2i[w] = i + 2

word_to_idx = new_w2i
embedding_matrix = torch.tensor(new_emb, dtype=torch.float32)
vocab_size = embedding_matrix.shape[0]

print("Final vocab with PAD/UNK:", vocab_size)

Original vocab: 14068 Embedding dim: 100
Final vocab with PAD/UNK: 14070


# DATA PREPARATION

In [3]:
dataset = load_dataset("lhoestq/conll2003")

train_split = dataset["train"]
val_split = dataset["validation"]
test_split = dataset["test"]


dataset_infos.json: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/1.07M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/281k [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/259k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/14041 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3250 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3453 [00:00<?, ? examples/s]

# WINDOW DATASET

In [4]:
class WindowNERDataset(Dataset):
    def __init__(self, split, word_to_idx, window_size):
        self.window_size = window_size
        self.pad = word_to_idx["<PAD>"]
        self.unk = word_to_idx["<UNK>"]
        self.samples = []

        for entry in split:
            tokens = entry["tokens"]
            labels = entry["ner_tags"]

            idxs = [word_to_idx.get(t.lower(), self.unk) for t in tokens]

            padded = [self.pad]*window_size + idxs + [self.pad]*window_size

            for i in range(len(tokens)):
                window = padded[i:i + 2*window_size + 1]
                self.samples.append((window, labels[i]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        w, y = self.samples[idx]
        return torch.tensor(w, dtype=torch.long), torch.tensor(y, dtype=torch.long)


# FEED FORWARD TAGGER

In [5]:
class FFNTagger(nn.Module):
    def __init__(self, embedding_matrix, window=2, hidden=256, hidden2=128, num_classes=9):
        super().__init__()

        self.embedding = nn.Embedding.from_pretrained(
            embedding_matrix,
            freeze=True      
        )

        input_dim = (2*window + 1) * embedding_matrix.size(1)

        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(hidden, hidden2),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(hidden2, num_classes)
        )

    def forward(self, x):
        emb = self.embedding(x)                      # (B, window, D)
        concat = emb.view(emb.size(0), -1)           # (B, window*D)
        return self.ffn(concat)                      # (B, num_classes)


# BUILD DATASETS & LOADERS

In [6]:
WINDOW = 2
HIDDEN_DIM = 256
NUM_CLASSES = 9

BATCH_SIZE = 128
EPOCHS = 60        
LR = 3e-3          

train_ds = WindowNERDataset(train_split, word_to_idx, WINDOW)
val_ds   = WindowNERDataset(val_split, word_to_idx, WINDOW)
test_ds  = WindowNERDataset(test_split,  word_to_idx, WINDOW)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE)

print("Train samples =", len(train_ds))
print("Val samples   =", len(val_ds))
print("Test samples  =", len(test_ds))


Train samples = 203621
Val samples   = 51362
Test samples  = 46435


# TRAINING & EVALUATION

In [7]:
def evaluate(model, loader, loss_fn):
    model.eval()
    total_loss = 0
    y_true, y_pred = [], []

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)
            loss = loss_fn(logits, y)
            total_loss += loss.item()

            preds = logits.argmax(1)
            y_true += y.cpu().tolist()
            y_pred += preds.cpu().tolist()

    avg_loss = total_loss / len(loader)
    report = classification_report(y_true, y_pred, digits=4)
    return avg_loss, report


# TRAINING LOOP

In [8]:
model = FFNTagger(embedding_matrix=embedding_matrix).to(DEVICE)


optimizer = optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()


scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,       
    patience=3,       
    verbose=True
)

best_val_loss = float("inf")


patience_limit = 6
patience_counter = 0

for epoch in range(1, EPOCHS+1):

    model.train()
    total_train = 0

    for x, y in train_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)

        optimizer.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()

        total_train += loss.item()

    avg_train_loss = total_train / len(train_loader)
    avg_val_loss, val_report = evaluate(model, val_loader, loss_fn)

    
    scheduler.step(avg_val_loss)

    print(f"\nEpoch {epoch}/{EPOCHS}")
    print(f"Train Loss = {avg_train_loss:.4f}")
    print(f"Val Loss   = {avg_val_loss:.4f}")
    print("Validation Report:")
    print(val_report)

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0                     
        torch.save(model.state_dict(), "best_ffn_model.pt")
        print(" Saved best model!")
    else:
        patience_counter += 1
        print(f"No improvement ({patience_counter}/{patience_limit})")

    # --- Early stopping break ---
    if patience_counter >= patience_limit:
        print("Early stopping triggered")
        break





Epoch 1/60
Train Loss = 0.3463
Val Loss   = 0.2733
Validation Report:
              precision    recall  f1-score   support

           0     0.9485    0.9789    0.9635     42759
           1     0.7627    0.8377    0.7984      1842
           2     0.8440    0.7950    0.8188      1307
           3     0.8036    0.4303    0.5605      1341
           4     0.7429    0.2770    0.4035       751
           5     0.7236    0.8095    0.7641      1837
           6     0.6126    0.5292    0.5678       257
           7     0.6076    0.3796    0.4673       922
           8     0.6719    0.2486    0.3629       346

    accuracy                         0.9206     51362
   macro avg     0.7464    0.5873    0.6341     51362
weighted avg     0.9147    0.9206    0.9131     51362

 Saved best model!

Epoch 2/60
Train Loss = 0.2818
Val Loss   = 0.2610
Validation Report:
              precision    recall  f1-score   support

           0     0.9331    0.9915    0.9614     42759
           1     0.7899  

# FINAL TEST EVALUATION

In [9]:
print("\nLoading best model...")
model.load_state_dict(torch.load("best_ffn_model.pt"))
model.to(DEVICE)

test_loss, test_report = evaluate(model, test_loader, loss_fn)

print("\n======== FINAL TEST RESULTS ========")
print("Test Loss =", test_loss)
print(test_report)
print("====================================")



Loading best model...

Test Loss = 0.24498711664435643
              precision    recall  f1-score   support

           0     0.9507    0.9877    0.9688     38323
           1     0.8474    0.8108    0.8287      1617
           2     0.8926    0.8054    0.8467      1156
           3     0.7866    0.5924    0.6758      1661
           4     0.7992    0.5102    0.6228       835
           5     0.8826    0.7572    0.8151      1668
           6     0.7756    0.6187    0.6883       257
           7     0.7091    0.4444    0.5464       702
           8     0.6512    0.5185    0.5773       216

    accuracy                         0.9335     46435
   macro avg     0.8106    0.6717    0.7300     46435
weighted avg     0.9286    0.9335    0.9289     46435

