# Transformer-CRF Fusion Models

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from torchcrf import CRF

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm

In [3]:
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
# device='cpu'
print(device)

cuda


In [4]:
# Model evaluation
from transformers import EvalPrediction
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from seqeval.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import classification_report

In [5]:
from prepare_dataset import prepare_data

In [6]:
# training constants
MODEL_NAME = 'bert-base-uncased' # let's try bert first
MAX_LEN = 128
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 8
EPOCHS = 3
LEARNING_RATE = 1e-05
MAX_GRAD_NORM = 10

## Data Loader

In [7]:
class dataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, index):
        sentence = self.data.sentence[index]
        word_labels = self.data.tags[index]

        encoding = self.tokenizer(sentence,
                                  is_split_into_words=True,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len)

        labels = [labels_to_ids[label] for label in word_labels]

        encoded_labels = np.ones(len(encoding["input_ids"]), dtype=int) * -100
        label_mask = np.zeros(len(encoding["input_ids"]), dtype=bool)
        word_ids = encoding.word_ids()

        previous_word_idx = None
        for idx, word_idx in enumerate(word_ids):
            if word_idx is None:
                continue
            elif word_idx != previous_word_idx:
                encoded_labels[idx] = labels[word_idx]
                label_mask[idx] = True  # mark this token as valid
                previous_word_idx = word_idx

        item = {key: torch.tensor(val) for key, val in encoding.items()}
        item['labels'] = torch.tensor(encoded_labels)
        label_mask[0] = True # force it starts with "on"
        item['label_mask'] = torch.tensor(label_mask)

        return item

    def __len__(self):
        return self.len

In [8]:
data_dict = prepare_data("../processed_notes.csv")
data = data_dict["data"]
labels_to_ids = data_dict["labels_to_ids"]
ids_to_labels = data_dict["ids_to_labels"]

In [9]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
NUM_LABELS = len(labels_to_ids)

In [10]:
train_size = 0.8
train_dataset = data.sample(frac=train_size,random_state=200)
test_dataset = data.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

print("FULL Dataset: {}".format(data.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))

training_set = dataset(train_dataset, tokenizer, MAX_LEN)
testing_set = dataset(test_dataset, tokenizer, MAX_LEN)

FULL Dataset: (20305, 4)
TRAIN Dataset: (16244, 4)
TEST Dataset: (4061, 4)


In [11]:
train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

In [12]:
labels_to_ids

{'O': 0,
 'B-DISO': 1,
 'I-DISO': 2,
 'B-PROC': 3,
 'I-PROC': 4,
 'B-ANAT': 5,
 'I-ANAT': 6,
 'B-UNK': 7,
 'B-ACTI': 8,
 'I-ACTI': 9,
 'B-PHYS': 10,
 'I-PHYS': 11,
 'B-PHEN': 12,
 'I-PHEN': 13,
 'B-CONC': 14,
 'B-CHEM': 15,
 'I-CONC': 16,
 'B-OBJC': 17,
 'I-UNK': 18,
 'B-DEVI': 19,
 'I-DEVI': 20,
 'B-LIVB': 21,
 'I-LIVB': 22}

## Model Construction

In [13]:
class BERT_CRF(nn.Module):
    def __init__(self, bert_model_name, num_labels):
        super(BERT_CRF, self).__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        self.hidden2tag = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None, label_mask=None):
        emissions = self.bert(input_ids=input_ids,
                              attention_mask=attention_mask)[0]
        emissions = self.dropout(emissions)
        emissions = self.hidden2tag(emissions)

        if labels is not None:
            labels = labels.clone()
            labels[labels == -100] = 0  # CRF will ignore these via label_mask
            loss = -self.crf(emissions, labels, mask=label_mask.bool(), reduction='mean')
            predictions = self.crf.decode(emissions, mask=label_mask.bool())
            return loss, predictions
        else:
            predictions = self.crf.decode(emissions, mask=label_mask.bool())
            return predictions


In [14]:
# Sanity Check
model = BERT_CRF(MODEL_NAME, num_labels=NUM_LABELS).to(device)
inputs = training_set[2]
input_ids = inputs["input_ids"].unsqueeze(0).to(device)
attention_mask = inputs["attention_mask"].unsqueeze(0).to(device)
labels = inputs["labels"].unsqueeze(0).to(device)
label_mask = inputs["label_mask"].unsqueeze(0).to(device)

input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
labels = labels.to(device)

outputs = model(input_ids, attention_mask=attention_mask, labels=labels, label_mask=label_mask)
initial_loss = outputs[0]
initial_loss

tensor(30.6187, device='cuda:0', grad_fn=<NegBackward0>)

In [15]:
# Initialization
model = BERT_CRF(MODEL_NAME, num_labels=NUM_LABELS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

In [18]:
def evaluate(model, eval_loader, device):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in eval_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            label_mask = batch["label_mask"].to(device)

            predictions = model(input_ids, attention_mask, labels=None, label_mask=label_mask)

            for pred_seq, label_seq, mask_seq in zip(predictions, labels, label_mask):
                label_seq = label_seq.cpu().numpy()
                mask_seq = mask_seq.cpu().numpy().astype(bool)
                mask_seq[0] = False

                # Extract only masked tokens
                true_label_ids = label_seq[mask_seq]
                pred_label_ids = pred_seq[:len(true_label_ids)]  # CRF may return exactly this many

                # Sanity check
                assert len(true_label_ids) == len(pred_label_ids), \
                    f"Mismatch in pred/true lengths: {len(pred_label_ids)} vs {len(true_label_ids)}"

                # Map to label strings
                all_preds.append([ids_to_labels[p] for p in pred_label_ids])
                all_labels.append([ids_to_labels[l] for l in true_label_ids])

    metrics = {
        "accuracy": accuracy_score(all_labels, all_preds),
        "precision": precision_score(all_labels, all_preds, average='macro'),
        "recall": recall_score(all_labels, all_preds, average='macro'),
        "f1": f1_score(all_labels, all_preds, average='macro'),
        "f1_all": f1_score(all_labels, all_preds, average=None),
    }
    return metrics

In [None]:
log_every = 200   # log every 50 batches
eval_every = 200 # evaluate every 200 batches
num_epochs = EPOCHS
log_history = []  # To store logged metrics
global_step = 0

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0
    num_batches = len(training_loader)
    
    for step, batch in enumerate(tqdm.tqdm(training_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        label_mask = batch['label_mask'].to(device)

        loss, _ = model(input_ids, attention_mask, labels, label_mask=label_mask)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        global_step += 1

        # Log training loss every 'log_every' steps
        if (step + 1) % log_every == 0:
            avg_loss = epoch_loss / (step + 1)
            log_history.append({
                "step": global_step,
                "loss": avg_loss
            })
            print(f"Step {global_step}: Avg Train Loss = {avg_loss:.4f}")

        # Run evaluation every 'eval_every' steps
        if (step + 1) % eval_every == 0:
            eval_metrics = evaluate(model, testing_loader, device)
            log_history.append({
                "step": global_step,
                "eval_f1": eval_metrics["f1"],
                "eval_accuracy": eval_metrics["accuracy"],
                "eval_precision": eval_metrics["precision"],
                "eval_recall": eval_metrics["recall"],
                "eval_loss": eval_metrics.get("eval_loss", None)  # if you compute eval loss
            })
            print(f"Step {global_step}: F1 = {eval_metrics['f1']:.4f}, Acc = {eval_metrics['accuracy']:.4f}, Prec = {eval_metrics['precision']:.4f}, Recall = {eval_metrics['recall']:.4f}")

    # End-of-epoch logging
    avg_epoch_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1} Completed: Avg Training Loss = {avg_epoch_loss:.4f}")
    eval_metrics = evaluate(model, testing_loader, device)
    print(f"Final: F1 = {eval_metrics['f1']:.4f}, Acc = {eval_metrics['accuracy']:.4f}, Prec = {eval_metrics['precision']:.4f}, Recall = {eval_metrics['recall']:.4f}")
    print(eval_metrics["f1_all"])


Epoch 1/3:  20%|█▉        | 199/1016 [00:36<02:29,  5.46it/s]

Step 200: Avg Train Loss = 5.8431


  _warn_prf(average, modifier, msg_start, len(result))
Epoch 1/3:  20%|█▉        | 201/1016 [00:50<41:33,  3.06s/it]

Step 200: F1 = 0.0139, Acc = 0.6925, Prec = 0.0125, Recall = 0.0161


Epoch 1/3:  39%|███▉      | 399/1016 [01:26<01:51,  5.53it/s]

Step 400: Avg Train Loss = 5.2085


Epoch 1/3:  39%|███▉      | 401/1016 [01:40<31:11,  3.04s/it]

Step 400: F1 = 0.0156, Acc = 0.7038, Prec = 0.0140, Recall = 0.0177


Epoch 1/3:  59%|█████▉    | 599/1016 [02:16<01:15,  5.53it/s]

Step 600: Avg Train Loss = 4.9233


Epoch 1/3:  59%|█████▉    | 601/1016 [02:30<21:04,  3.05s/it]

Step 600: F1 = 0.0158, Acc = 0.6956, Prec = 0.0139, Recall = 0.0185


Epoch 1/3:  79%|███████▊  | 799/1016 [03:06<00:39,  5.51it/s]

Step 800: Avg Train Loss = 4.7138


Epoch 1/3:  79%|███████▉  | 801/1016 [03:20<10:57,  3.06s/it]

Step 800: F1 = 0.0165, Acc = 0.6874, Prec = 0.0141, Recall = 0.0200


Epoch 1/3:  98%|█████████▊| 999/1016 [03:56<00:03,  5.49it/s]

Step 1000: Avg Train Loss = 4.5498


Epoch 1/3:  99%|█████████▊| 1001/1016 [04:10<00:45,  3.06s/it]

Step 1000: F1 = 0.0155, Acc = 0.6895, Prec = 0.0136, Recall = 0.0179


Epoch 1/3: 100%|██████████| 1016/1016 [04:13<00:00,  4.01it/s]


Epoch 1 Completed: Avg Training Loss = 4.5415
Final: F1 = 0.0169, Acc = 0.6893, Prec = 0.0143, Recall = 0.0207
[0.         0.04231166 0.         0.         0.09140738 0.
 0.         0.         0.03548681 0.        ]


Epoch 2/3:  20%|█▉        | 199/1016 [00:36<02:30,  5.44it/s]

Step 1216: Avg Train Loss = 3.0506


Epoch 2/3:  20%|█▉        | 201/1016 [00:50<41:34,  3.06s/it]

Step 1216: F1 = 0.0170, Acc = 0.6689, Prec = 0.0141, Recall = 0.0214


Epoch 2/3:  39%|███▉      | 399/1016 [01:26<01:51,  5.52it/s]

Step 1416: Avg Train Loss = 2.7843


Epoch 2/3:  39%|███▉      | 401/1016 [01:40<31:16,  3.05s/it]

Step 1416: F1 = 0.0155, Acc = 0.6777, Prec = 0.0135, Recall = 0.0185


Epoch 2/3:  59%|█████▉    | 599/1016 [02:16<01:15,  5.49it/s]

Step 1616: Avg Train Loss = 2.8009


Epoch 2/3:  59%|█████▉    | 601/1016 [02:30<21:09,  3.06s/it]

Step 1616: F1 = 0.0154, Acc = 0.6892, Prec = 0.0137, Recall = 0.0176


Epoch 2/3:  79%|███████▊  | 799/1016 [03:06<00:39,  5.51it/s]

Step 1816: Avg Train Loss = 2.7403


Epoch 2/3:  79%|███████▉  | 801/1016 [03:20<10:58,  3.06s/it]

Step 1816: F1 = 0.0166, Acc = 0.6774, Prec = 0.0143, Recall = 0.0200


Epoch 2/3:  98%|█████████▊| 999/1016 [03:56<00:03,  5.53it/s]

Step 2016: Avg Train Loss = 2.7224


In [None]:
def plot_eval():
    train_steps, train_loss = [], []
    eval_steps, eval_loss, f1s = [], [], []
    
    for entry in log_history:
        if "loss" in entry and "step" in entry:
            train_steps.append(entry["step"])
            train_loss.append(entry["loss"])
        if "eval_f1" in entry:
            eval_steps.append(entry["step"])
            eval_loss.append(entry.get("eval_loss", 0))  # or None if not computed
            f1s.append(entry["eval_f1"])
    
    plt.figure(figsize=(12, 5))
    
    # Plot losses
    plt.subplot(1, 2, 1)
    if train_loss:
        plt.plot(train_steps, train_loss, label="Train Loss")
    if eval_loss:
        plt.plot(eval_steps, eval_loss, label="Eval Loss")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.title("Training Loss")
    plt.legend()
    
    # Plot F1
    plt.subplot(1, 2, 2)
    if f1s[0] is not None:
        plt.plot(eval_steps, f1s, label="Eval F1", color='green')
        plt.xlabel("Steps")
        plt.ylabel("F1 Score")
        plt.title("Evaluation F1 over Time")
        plt.legend()
    
    plt.tight_layout()

In [None]:
plot_eval()