In [1]:
!pip install seqeval



In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
from gensim.models import Word2Vec
from seqeval.metrics import classification_report, f1_score
from tqdm import tqdm
import os

### `CNN Token Classification (NER)`

In [3]:
# --- 1. CONFIGURATION ---
# Hyper-parameters and runtime settings shared by the cells below.
SEQUENCE_LENGTH = 128  # tokens per sequence (loader fallback chunk size, eval loop bound)
EMBEDDING_DIM = 100  # passed to CNN_NER as its input dimension
BATCH_SIZE = 32  # batch size of the train / validation / test DataLoaders
EPOCHS = 100  # number of passes over the training loader
LEARNING_RATE = 0.005  # learning rate given to Adam
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the global RNGs so weight initialisation, dropout and DataLoader
# shuffling are repeatable from one "Restart & Run All" to the next
# (GPU kernels may still introduce small nondeterminism).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# Paths
# Word2Vec model file (the TP1 model) loaded later with Word2Vec.load.
W2V_PATH = "w2v_med_cbow.model"

# CSV files: every split combines the MEDLINE and EMEA files, in that order.
_QUAERO_DIR = "TP_ISD2020/QUAERO_FrenchMed"
_CORPORA = ("MEDLINE", "EMEA")


def _split_files(split):
    """Return the per-corpus CSV paths for one split name ("train", "dev", "test")."""
    return [
        f"{_QUAERO_DIR}/{corpus}/{corpus}{split}_layer1_ID.csv" for corpus in _CORPORA
    ]


TRAIN_FILES = _split_files("train")
VALID_FILES = _split_files("dev")
TEST_FILES = _split_files("test")

In [5]:
def load_data_from_csv(file_paths, chunk_len=None):
    """Load token/tag CSV files and group their rows into sentences.

    Each file is read with pandas using separator auto-detection. Tokens come
    from the "Mot" column and labels from the "Tag" column when both exist;
    otherwise the first and last columns are used. A row whose token is empty
    or whitespace-only closes the current sentence.

    If a file with more than 500 rows yields fewer than 10 sentences, the
    sentence breaks are considered missing and the tokens are re-cut into
    consecutive chunks of ``chunk_len`` tokens.

    NOTE(review): the recorded notebook output shows this fallback firing for
    every file, i.e. no empty-token rows were detected. Sentence boundaries
    may be encoded differently in these CSVs -- TODO confirm against the data.

    Args:
        file_paths: Iterable of CSV paths. Missing or unreadable files are
            reported on stdout and skipped.
        chunk_len: Chunk size for the fallback split. When omitted, the
            module-level ``SEQUENCE_LENGTH`` is used.

    Returns:
        A pair ``(sentences, tags)`` of parallel lists; each element is the
        list of token strings (resp. tag strings) of one sentence or chunk.
    """
    all_sentences = []
    all_tags = []

    for fpath in file_paths:
        if not os.path.exists(fpath):
            print(f"❌ File not found: {fpath}")
            continue

        print(f"Loading {os.path.basename(fpath)}...", end=" ")

        # Read CSV with auto-separator detection
        try:
            df = pd.read_csv(
                fpath,
                sep=None,
                engine="python",
                keep_default_na=False,
                skip_blank_lines=False,
            )
        except Exception as exc:
            # Catch Exception (not a bare except) so Ctrl-C still interrupts,
            # and report why the file was skipped instead of hiding the cause.
            print(f"Read failed: {exc}")
            continue

        # Extract Words and Tags
        if "Mot" in df.columns and "Tag" in df.columns:
            words = df["Mot"].astype(str).values
            tags = df["Tag"].astype(str).values
        else:
            words = df.iloc[:, 0].astype(str).values
            tags = df.iloc[:, -1].astype(str).values

        # Group into sentences
        curr_s, curr_t = [], []
        file_s, file_t = [], []

        for w, t in zip(words, tags):
            if not w.strip():  # Empty line = Sentence Break
                if curr_s:
                    file_s.append(curr_s)
                    file_t.append(curr_t)
                    curr_s, curr_t = [], []
            else:
                curr_s.append(w)
                curr_t.append(t)
        if curr_s:
            file_s.append(curr_s)
            file_t.append(curr_t)

        # Fallback: If chunking failed (1 giant sentence), force split
        if len(file_s) < 10 and len(words) > 500:
            print("[Chunking Fallback]", end=" ")
            size = SEQUENCE_LENGTH if chunk_len is None else chunk_len
            flat_w = [w for s in file_s for w in s]
            flat_t = [t for s in file_t for t in s]
            file_s = [flat_w[i : i + size] for i in range(0, len(flat_w), size)]
            file_t = [flat_t[i : i + size] for i in range(0, len(flat_t), size)]

        print(f"-> {len(file_s)} sentences.")
        all_sentences.extend(file_s)
        all_tags.extend(file_t)

    return all_sentences, all_tags


print("--- 1. LOADING TEXT DATA ---")
# Load the three splits in train / validation / test order.
(train_sents, train_tags), (valid_sents, valid_tags), (test_sents, test_tags) = [
    load_data_from_csv(paths) for paths in (TRAIN_FILES, VALID_FILES, TEST_FILES)
]

# Report how many sequences the train and test splits produced.
print(f"Train Size: {len(train_sents)}")
print(f"Test Size:  {len(test_sents)} (Should be > 1000)")

--- 1. LOADING TEXT DATA ---
Loading MEDLINEtrain_layer1_ID.csv... [Chunking Fallback] -> 91 sentences.
Loading EMEAtrain_layer1_ID.csv... [Chunking Fallback] -> 120 sentences.
Loading MEDLINEdev_layer1_ID.csv... [Chunking Fallback] -> 90 sentences.
Loading EMEAdev_layer1_ID.csv... [Chunking Fallback] -> 106 sentences.
Loading MEDLINEtest_layer1_ID.csv... [Chunking Fallback] -> 94 sentences.
Loading EMEAtest_layer1_ID.csv... [Chunking Fallback] -> 97 sentences.
Train Size: 211
Test Size:  191 (Should be > 1000)


In [6]:
print("\n--- 2. VECTORIZING FEATURES ---")
# Load the Word2Vec model stored at W2V_PATH; vectorize() reads its .wv lookup.
w2v_model = Word2Vec.load(W2V_PATH)


def vectorize(sentences, model, max_len=128, dim=100):
    """Turn tokenised sentences into a dense embedding tensor.

    Each sentence is cut to at most ``max_len`` tokens. A token's vector is
    read from ``model.wv`` using the token as written, or its lowercase form
    when the exact form is absent. Tokens missing in both forms, and positions
    past the end of a sentence, stay all-zero.

    Returns:
        A float32 tensor of shape ``(len(sentences), max_len, dim)``.
    """
    vocab = model.wv
    features = np.zeros((len(sentences), max_len, dim), dtype=np.float32)
    for row, tokens in enumerate(sentences):
        for col, token in enumerate(tokens[:max_len]):
            # Exact match first, then lowercase; otherwise leave the zero vector.
            if token in vocab:
                key = token
            elif token.lower() in vocab:
                key = token.lower()
            else:
                continue
            features[row, col] = vocab[key]
    return torch.tensor(features)


# Pass the notebook configuration explicitly instead of relying on
# vectorize()'s built-in defaults, so the feature tensors follow
# SEQUENCE_LENGTH / EMBEDDING_DIM when those settings are changed.
X_train = vectorize(train_sents, w2v_model, max_len=SEQUENCE_LENGTH, dim=EMBEDDING_DIM)
X_valid = vectorize(valid_sents, w2v_model, max_len=SEQUENCE_LENGTH, dim=EMBEDDING_DIM)
X_test = vectorize(test_sents, w2v_model, max_len=SEQUENCE_LENGTH, dim=EMBEDDING_DIM)

# --- 4. ENCODING LABELS (Y) ---
print("--- 3. ENCODING LABELS ---")
# Gather every tag seen in any split; tags get indices 1..N in sorted order
# and index 0 is reserved for the padding label.
tag_set = {
    tag
    for sentence_tags in train_tags + valid_tags + test_tags
    for tag in sentence_tags
}
tag2idx = dict(zip(sorted(tag_set), range(1, len(tag_set) + 1)))
tag2idx["<PAD>"] = 0
# Reverse lookup used when decoding predictions back to tag strings.
idx2tag = {index: tag for tag, index in tag2idx.items()}
print(f"Tags: {tag2idx}")


def encode_labels(labels, mapping, max_len=128):
    """Convert tag sequences into a fixed-width tensor of class indices.

    Each tag is looked up in ``mapping``; tags that are not present map to 0.
    Every sequence is then truncated to ``max_len`` entries or right-padded
    with 0 up to ``max_len``.

    Returns:
        A ``torch.long`` tensor with one row of ``max_len`` indices per sequence.
    """
    rows = []
    for tags in labels:
        ids = [mapping.get(tag, 0) for tag in tags][:max_len]
        # A non-positive repeat count yields an empty list, so full rows get no padding.
        rows.append(ids + [0] * (max_len - len(ids)))
    return torch.tensor(rows, dtype=torch.long)


# Pass SEQUENCE_LENGTH explicitly instead of relying on encode_labels()'s
# built-in default, so the label width follows the notebook configuration.
y_train = encode_labels(train_tags, tag2idx, max_len=SEQUENCE_LENGTH)
y_valid = encode_labels(valid_tags, tag2idx, max_len=SEQUENCE_LENGTH)
y_test = encode_labels(test_tags, tag2idx, max_len=SEQUENCE_LENGTH)

# DataLoaders: only the training split is shuffled.
train_loader = DataLoader(
    TensorDataset(X_train, y_train), shuffle=True, batch_size=BATCH_SIZE
)
valid_loader = DataLoader(
    TensorDataset(X_valid, y_valid), shuffle=False, batch_size=BATCH_SIZE
)
test_loader = DataLoader(
    TensorDataset(X_test, y_test), shuffle=False, batch_size=BATCH_SIZE
)


--- 2. VECTORIZING FEATURES ---
--- 3. ENCODING LABELS ---
Tags: {'B-ANAT': 1, 'B-CHEM': 2, 'B-DEVI': 3, 'B-DISO': 4, 'B-GEOG': 5, 'B-LIVB': 6, 'B-OBJC': 7, 'B-PHEN': 8, 'B-PHYS': 9, 'B-PROC': 10, 'I-ANAT': 11, 'I-CHEM': 12, 'I-DEVI': 13, 'I-DISO': 14, 'I-GEOG': 15, 'I-LIVB': 16, 'I-OBJC': 17, 'I-PHEN': 18, 'I-PHYS': 19, 'I-PROC': 20, 'O': 21, '<PAD>': 0}


In [7]:
# --- 5. CNN MODEL ---
class CNN_NER(nn.Module):
    """Two-layer 1-D convolutional tagger that scores every sequence position.

    Input:  tensor laid out as [Batch, Seq, input_dim].
    Output: logits laid out as [Batch, Seq, num_classes].
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # kernel_size=3 with padding=1 keeps the sequence length unchanged.
        self.conv1 = nn.Conv1d(
            in_channels=input_dim, out_channels=128, kernel_size=3, padding=1
        )
        self.conv2 = nn.Conv1d(
            in_channels=128, out_channels=256, kernel_size=3, padding=1
        )
        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(in_features=256, out_features=num_classes)

    def forward(self, x):
        # Conv1d convolves over the last axis, so move features to axis 1: [Batch, Dim, Seq].
        channels_first = x.transpose(1, 2)
        hidden = self.dropout(F.relu(self.conv1(channels_first)))
        hidden = self.dropout(F.relu(self.conv2(hidden)))
        # Back to [Batch, Seq, 256] so the linear layer scores each position.
        return self.fc(hidden.transpose(1, 2))


num_classes = len(tag2idx)
model = CNN_NER(input_dim=EMBEDDING_DIM, num_classes=num_classes).to(DEVICE)

# Down-weight the "O" class (0.5 versus 1.0 for every other class) to
# counter its dominance in the loss.
weights = torch.ones(num_classes, device=DEVICE)
if "O" in tag2idx:
    weights[tag2idx["O"]] = 0.5
# Targets equal to 0 (the <PAD> index) are excluded from the loss.
criterion = nn.CrossEntropyLoss(weight=weights, ignore_index=0)
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

In [8]:
# --- 6. TRAINING ---
print(f"\n--- 4. TRAINING ON {DEVICE} ---")
# Start below any reachable F1 so the first epoch always writes a checkpoint;
# the final-test cell loads "best_ner_cnn.pt" unconditionally.
best_f1 = -1.0

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        out = model(x)
        # Flatten to one row per token position for the loss.
        loss = criterion(out.view(-1, len(tag2idx)), y.view(-1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation
    model.eval()
    all_true, all_pred = [], []
    with torch.no_grad():
        for x, y in valid_loader:
            out = model(x.to(DEVICE))
            # <PAD> (index 0) is never a valid label for a real token: take the
            # argmax over the real classes only, then shift back to tag indices.
            preds = (torch.argmax(out[:, :, 1:], dim=2) + 1).cpu().numpy()
            labels = y.numpy()

            # Iterate over the actual tensor width rather than a global constant.
            for pred_row, label_row in zip(preds, labels):
                p_s, t_s = [], []
                for p, t in zip(pred_row, label_row):
                    if t == 0:  # first padding label ends the sequence
                        break
                    p_s.append(idx2tag[int(p)])
                    t_s.append(idx2tag[int(t)])
                all_pred.append(p_s)
                all_true.append(t_s)

    val_f1 = f1_score(all_true, all_pred)
    print(f"Loss: {train_loss/len(train_loader):.4f} | Val F1: {val_f1:.4f}")

    # Keep only the weights of the best validation epoch so far.
    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save(model.state_dict(), "best_ner_cnn.pt")


--- 4. TRAINING ON cpu ---


Epoch 1: 100%|██████████| 7/7 [00:00<00:00,  8.69it/s]


Loss: 2.2331 | Val F1: 0.0590


Epoch 2: 100%|██████████| 7/7 [00:00<00:00,  9.19it/s]


Loss: 1.6785 | Val F1: 0.1137


Epoch 3: 100%|██████████| 7/7 [00:00<00:00,  9.69it/s]


Loss: 1.4758 | Val F1: 0.1504


Epoch 4: 100%|██████████| 7/7 [00:00<00:00, 10.61it/s]


Loss: 1.3675 | Val F1: 0.1971


Epoch 5: 100%|██████████| 7/7 [00:00<00:00,  8.64it/s]


Loss: 1.3073 | Val F1: 0.2392


Epoch 6: 100%|██████████| 7/7 [00:00<00:00,  9.00it/s]


Loss: 1.2584 | Val F1: 0.2745


Epoch 7: 100%|██████████| 7/7 [00:00<00:00,  9.10it/s]


Loss: 1.2243 | Val F1: 0.2883


Epoch 8: 100%|██████████| 7/7 [00:00<00:00, 10.71it/s]


Loss: 1.1908 | Val F1: 0.2958


Epoch 9: 100%|██████████| 7/7 [00:00<00:00,  8.90it/s]


Loss: 1.1624 | Val F1: 0.3216


Epoch 10: 100%|██████████| 7/7 [00:00<00:00, 10.42it/s]


Loss: 1.1276 | Val F1: 0.3349


Epoch 11: 100%|██████████| 7/7 [00:00<00:00, 10.82it/s]


Loss: 1.1060 | Val F1: 0.3427


Epoch 12: 100%|██████████| 7/7 [00:00<00:00, 10.79it/s]


Loss: 1.0920 | Val F1: 0.3521


Epoch 13: 100%|██████████| 7/7 [00:00<00:00,  9.50it/s]


Loss: 1.0809 | Val F1: 0.3507


Epoch 14: 100%|██████████| 7/7 [00:00<00:00,  9.07it/s]


Loss: 1.0708 | Val F1: 0.3377


Epoch 15: 100%|██████████| 7/7 [00:00<00:00, 10.45it/s]


Loss: 1.0567 | Val F1: 0.3397


Epoch 16: 100%|██████████| 7/7 [00:00<00:00, 10.24it/s]


Loss: 1.0361 | Val F1: 0.3517


Epoch 17: 100%|██████████| 7/7 [00:00<00:00, 10.04it/s]


Loss: 1.0388 | Val F1: 0.3405


Epoch 18: 100%|██████████| 7/7 [00:00<00:00, 10.73it/s]


Loss: 1.0383 | Val F1: 0.3591


Epoch 19: 100%|██████████| 7/7 [00:00<00:00, 10.66it/s]


Loss: 1.0099 | Val F1: 0.3598


Epoch 20: 100%|██████████| 7/7 [00:00<00:00, 10.94it/s]


Loss: 0.9950 | Val F1: 0.3542


Epoch 21: 100%|██████████| 7/7 [00:00<00:00, 11.02it/s]


Loss: 0.9868 | Val F1: 0.3589


Epoch 22: 100%|██████████| 7/7 [00:00<00:00,  9.26it/s]


Loss: 0.9841 | Val F1: 0.3708


Epoch 23: 100%|██████████| 7/7 [00:00<00:00, 11.02it/s]


Loss: 0.9750 | Val F1: 0.3516


Epoch 24: 100%|██████████| 7/7 [00:00<00:00, 10.46it/s]


Loss: 0.9701 | Val F1: 0.3690


Epoch 25: 100%|██████████| 7/7 [00:00<00:00, 11.11it/s]


Loss: 0.9612 | Val F1: 0.3741


Epoch 26: 100%|██████████| 7/7 [00:00<00:00, 10.72it/s]


Loss: 0.9587 | Val F1: 0.3707


Epoch 27: 100%|██████████| 7/7 [00:00<00:00, 10.92it/s]


Loss: 0.9554 | Val F1: 0.3691


Epoch 28: 100%|██████████| 7/7 [00:00<00:00, 10.68it/s]


Loss: 0.9438 | Val F1: 0.3773


Epoch 29: 100%|██████████| 7/7 [00:00<00:00, 10.50it/s]


Loss: 0.9444 | Val F1: 0.3927


Epoch 30: 100%|██████████| 7/7 [00:00<00:00, 10.86it/s]


Loss: 0.9419 | Val F1: 0.3831


Epoch 31: 100%|██████████| 7/7 [00:00<00:00, 11.07it/s]


Loss: 0.9182 | Val F1: 0.3926


Epoch 32: 100%|██████████| 7/7 [00:00<00:00, 10.80it/s]


Loss: 0.9229 | Val F1: 0.3890


Epoch 33: 100%|██████████| 7/7 [00:00<00:00, 10.87it/s]


Loss: 0.9315 | Val F1: 0.3754


Epoch 34: 100%|██████████| 7/7 [00:00<00:00, 11.00it/s]


Loss: 0.9276 | Val F1: 0.3847


Epoch 35: 100%|██████████| 7/7 [00:00<00:00, 11.08it/s]


Loss: 0.9110 | Val F1: 0.3660


Epoch 36: 100%|██████████| 7/7 [00:01<00:00,  6.66it/s]


Loss: 0.9116 | Val F1: 0.3864


Epoch 37: 100%|██████████| 7/7 [00:00<00:00, 10.80it/s]


Loss: 0.9112 | Val F1: 0.3802


Epoch 38: 100%|██████████| 7/7 [00:00<00:00, 10.77it/s]


Loss: 0.8956 | Val F1: 0.3865


Epoch 39: 100%|██████████| 7/7 [00:00<00:00, 10.67it/s]


Loss: 0.8869 | Val F1: 0.4050


Epoch 40: 100%|██████████| 7/7 [00:00<00:00, 10.82it/s]


Loss: 0.9015 | Val F1: 0.3966


Epoch 41: 100%|██████████| 7/7 [00:00<00:00, 11.01it/s]


Loss: 0.8987 | Val F1: 0.3903


Epoch 42: 100%|██████████| 7/7 [00:00<00:00, 11.37it/s]


Loss: 0.8820 | Val F1: 0.3956


Epoch 43: 100%|██████████| 7/7 [00:00<00:00, 11.17it/s]


Loss: 0.8838 | Val F1: 0.3830


Epoch 44: 100%|██████████| 7/7 [00:00<00:00, 11.26it/s]


Loss: 0.8862 | Val F1: 0.3569


Epoch 45: 100%|██████████| 7/7 [00:00<00:00, 11.41it/s]


Loss: 0.8922 | Val F1: 0.3986


Epoch 46: 100%|██████████| 7/7 [00:00<00:00, 11.39it/s]


Loss: 0.8816 | Val F1: 0.4005


Epoch 47: 100%|██████████| 7/7 [00:00<00:00, 11.36it/s]


Loss: 0.8726 | Val F1: 0.3929


Epoch 48: 100%|██████████| 7/7 [00:00<00:00, 11.29it/s]


Loss: 0.8819 | Val F1: 0.3847


Epoch 49: 100%|██████████| 7/7 [00:00<00:00, 10.72it/s]


Loss: 0.8674 | Val F1: 0.4004


Epoch 50: 100%|██████████| 7/7 [00:00<00:00, 10.69it/s]


Loss: 0.8540 | Val F1: 0.4000


Epoch 51: 100%|██████████| 7/7 [00:00<00:00, 10.71it/s]


Loss: 0.8473 | Val F1: 0.4014


Epoch 52: 100%|██████████| 7/7 [00:00<00:00, 11.00it/s]


Loss: 0.8549 | Val F1: 0.3834


Epoch 53: 100%|██████████| 7/7 [00:00<00:00, 10.78it/s]


Loss: 0.8622 | Val F1: 0.3999


Epoch 54: 100%|██████████| 7/7 [00:00<00:00, 11.45it/s]


Loss: 0.8484 | Val F1: 0.3975


Epoch 55: 100%|██████████| 7/7 [00:00<00:00, 11.23it/s]


Loss: 0.8594 | Val F1: 0.4084


Epoch 56: 100%|██████████| 7/7 [00:00<00:00, 11.19it/s]


Loss: 0.8441 | Val F1: 0.4020


Epoch 57: 100%|██████████| 7/7 [00:00<00:00, 11.08it/s]


Loss: 0.8372 | Val F1: 0.3947


Epoch 58: 100%|██████████| 7/7 [00:00<00:00, 11.20it/s]


Loss: 0.8456 | Val F1: 0.4000


Epoch 59: 100%|██████████| 7/7 [00:00<00:00, 11.07it/s]


Loss: 0.8372 | Val F1: 0.3979


Epoch 60: 100%|██████████| 7/7 [00:00<00:00, 11.21it/s]


Loss: 0.8380 | Val F1: 0.4010


Epoch 61: 100%|██████████| 7/7 [00:00<00:00,  9.49it/s]


Loss: 0.8344 | Val F1: 0.3959


Epoch 62: 100%|██████████| 7/7 [00:00<00:00, 11.23it/s]


Loss: 0.8289 | Val F1: 0.4018


Epoch 63: 100%|██████████| 7/7 [00:00<00:00,  9.96it/s]


Loss: 0.8262 | Val F1: 0.3939


Epoch 64: 100%|██████████| 7/7 [00:00<00:00, 10.84it/s]


Loss: 0.8355 | Val F1: 0.3981


Epoch 65: 100%|██████████| 7/7 [00:00<00:00, 10.41it/s]


Loss: 0.8086 | Val F1: 0.4120


Epoch 66: 100%|██████████| 7/7 [00:00<00:00,  8.80it/s]


Loss: 0.8325 | Val F1: 0.3989


Epoch 67: 100%|██████████| 7/7 [00:00<00:00, 10.45it/s]


Loss: 0.8272 | Val F1: 0.3948


Epoch 68: 100%|██████████| 7/7 [00:00<00:00, 10.60it/s]


Loss: 0.8222 | Val F1: 0.3946


Epoch 69: 100%|██████████| 7/7 [00:00<00:00,  9.50it/s]


Loss: 0.8160 | Val F1: 0.3984


Epoch 70: 100%|██████████| 7/7 [00:00<00:00, 10.30it/s]


Loss: 0.8189 | Val F1: 0.4091


Epoch 71: 100%|██████████| 7/7 [00:00<00:00, 11.24it/s]


Loss: 0.8086 | Val F1: 0.3888


Epoch 72: 100%|██████████| 7/7 [00:00<00:00, 11.38it/s]


Loss: 0.8073 | Val F1: 0.4066


Epoch 73: 100%|██████████| 7/7 [00:00<00:00, 11.29it/s]


Loss: 0.8150 | Val F1: 0.3984


Epoch 74: 100%|██████████| 7/7 [00:00<00:00,  9.88it/s]


Loss: 0.8107 | Val F1: 0.4100


Epoch 75: 100%|██████████| 7/7 [00:00<00:00, 11.11it/s]


Loss: 0.8179 | Val F1: 0.4071


Epoch 76: 100%|██████████| 7/7 [00:00<00:00, 11.33it/s]


Loss: 0.8088 | Val F1: 0.3919


Epoch 77: 100%|██████████| 7/7 [00:00<00:00, 10.93it/s]


Loss: 0.8055 | Val F1: 0.4115


Epoch 78: 100%|██████████| 7/7 [00:00<00:00, 11.36it/s]


Loss: 0.8076 | Val F1: 0.4059


Epoch 79: 100%|██████████| 7/7 [00:00<00:00, 11.38it/s]


Loss: 0.7992 | Val F1: 0.3897


Epoch 80: 100%|██████████| 7/7 [00:00<00:00, 11.28it/s]


Loss: 0.8050 | Val F1: 0.3896


Epoch 81: 100%|██████████| 7/7 [00:00<00:00, 11.28it/s]


Loss: 0.8060 | Val F1: 0.3913


Epoch 82: 100%|██████████| 7/7 [00:00<00:00, 11.35it/s]


Loss: 0.7999 | Val F1: 0.3930


Epoch 83: 100%|██████████| 7/7 [00:00<00:00, 11.38it/s]


Loss: 0.8029 | Val F1: 0.4070


Epoch 84: 100%|██████████| 7/7 [00:00<00:00, 11.37it/s]


Loss: 0.7966 | Val F1: 0.4090


Epoch 85: 100%|██████████| 7/7 [00:00<00:00,  9.91it/s]


Loss: 0.7979 | Val F1: 0.4002


Epoch 86: 100%|██████████| 7/7 [00:00<00:00, 11.25it/s]


Loss: 0.7964 | Val F1: 0.4042


Epoch 87: 100%|██████████| 7/7 [00:00<00:00, 11.37it/s]


Loss: 0.7904 | Val F1: 0.3999


Epoch 88: 100%|██████████| 7/7 [00:00<00:00, 11.32it/s]


Loss: 0.7926 | Val F1: 0.3988


Epoch 89: 100%|██████████| 7/7 [00:00<00:00, 11.11it/s]


Loss: 0.8038 | Val F1: 0.3967


Epoch 90: 100%|██████████| 7/7 [00:00<00:00, 10.50it/s]


Loss: 0.7945 | Val F1: 0.4077


Epoch 91: 100%|██████████| 7/7 [00:00<00:00, 10.43it/s]


Loss: 0.7894 | Val F1: 0.4081


Epoch 92: 100%|██████████| 7/7 [00:00<00:00, 11.14it/s]


Loss: 0.7993 | Val F1: 0.3991


Epoch 93: 100%|██████████| 7/7 [00:00<00:00, 10.99it/s]


Loss: 0.7877 | Val F1: 0.3958


Epoch 94: 100%|██████████| 7/7 [00:00<00:00, 10.78it/s]


Loss: 0.7872 | Val F1: 0.4025


Epoch 95: 100%|██████████| 7/7 [00:00<00:00, 11.14it/s]


Loss: 0.7853 | Val F1: 0.3993


Epoch 96: 100%|██████████| 7/7 [00:00<00:00,  9.64it/s]


Loss: 0.7898 | Val F1: 0.4001


Epoch 97: 100%|██████████| 7/7 [00:00<00:00, 11.01it/s]


Loss: 0.7876 | Val F1: 0.4062


Epoch 98: 100%|██████████| 7/7 [00:00<00:00, 10.36it/s]


Loss: 0.7823 | Val F1: 0.3933


Epoch 99: 100%|██████████| 7/7 [00:00<00:00, 10.28it/s]


Loss: 0.7749 | Val F1: 0.4026


Epoch 100: 100%|██████████| 7/7 [00:00<00:00, 10.69it/s]


Loss: 0.7793 | Val F1: 0.4043


In [9]:
# --- 7. FINAL TEST ---
print("\n--- FINAL TEST EVALUATION ---")
# map_location remaps the stored tensors onto the current DEVICE, so a
# checkpoint written on GPU still loads on a CPU-only machine.
model.load_state_dict(torch.load("best_ner_cnn.pt", map_location=DEVICE))
model.eval()
test_true, test_pred = [], []

with torch.no_grad():
    for x, y in test_loader:
        out = model(x.to(DEVICE))
        # <PAD> (index 0) is never a valid label for a real token: take the
        # argmax over the real classes only, then shift back to tag indices.
        preds = (torch.argmax(out[:, :, 1:], dim=2) + 1).cpu().numpy()
        labels = y.numpy()
        # Iterate over the actual tensor width rather than a global constant.
        for pred_row, label_row in zip(preds, labels):
            p_s, t_s = [], []
            for p, t in zip(pred_row, label_row):
                if t == 0:  # first padding label ends the sequence
                    break
                p_s.append(idx2tag[int(p)])
                t_s.append(idx2tag[int(t)])
            test_pred.append(p_s)
            test_true.append(t_s)

print(classification_report(test_true, test_pred))


--- FINAL TEST EVALUATION ---
              precision    recall  f1-score   support

        ANAT       0.26      0.09      0.13       364
        CHEM       0.33      0.21      0.25      1037
        DEVI       0.00      0.00      0.00       107
        DISO       0.26      0.28      0.27       977
        GEOG       0.55      0.10      0.16        63
        LIVB       0.67      0.52      0.59       498
        OBJC       0.18      0.07      0.10        81
        PHEN       0.50      0.04      0.08        70
        PHYS       0.50      0.15      0.23       190
        PROC       0.50      0.35      0.41       761

   micro avg       0.38      0.26      0.31      4148
   macro avg       0.37      0.18      0.22      4148
weighted avg       0.38      0.26      0.30      4148

