In [1]:
entity_label_set = ['O','Disease', 'SportsManager', 'Software', 'WrittenWork', 'Food', 'Scientist', 'OtherLOC', 'Cleric', 'Medication/Vaccine', 'PublicCorp', 'VisualWork', 'OtherPER', 'Artist', 'Symptom', 'SportsGRP', 'MedicalProcedure', 'Athlete', 'PrivateCorp', 'ORG', 'Politician', 'ArtWork', 'Drink', 'Vehicle', 'MusicalGRP', 'AnatomicalStructure', 'HumanSettlement', 'CarManufacturer', 'Facility', 'AerospaceManufacturer', 'OtherPROD', 'Clothing', 'MusicalWork', 'Station']

In [2]:
import os 
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [3]:
import torch
import torch.nn as nn
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader, Subset
from transformers import XLMRobertaTokenizerFast, XLMRobertaModel
import random

In [4]:

# Set random seed for reproducibility
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# Load the MultiCoNER dataset (English)
dataset = load_dataset('MultiCoNER/multiconer_v2', 'English (EN)')
train_dataset = dataset['train']
validation_dataset = dataset['validation']
test_dataset = dataset['test']

# Select 100 random samples for low-resource setting
train_indices = random.sample(range(len(train_dataset)), 100)
train_subset = Subset(train_dataset, train_indices)

# Use the fast tokenizer
tokenizer = XLMRobertaTokenizerFast.from_pretrained('xlm-roberta-large')
encoder = XLMRobertaModel.from_pretrained('xlm-roberta-large')

In [5]:
# Define label sets
span_label_set = ['B', 'I', 'O']
# entity_label_set = ['Person', 'Location', 'Corporation', 'Groups', 'Product', 'CreativeWork', 'O']
span2id = {label: idx for idx, label in enumerate(span_label_set)}
entity2id = {label: idx for idx, label in enumerate(entity_label_set)}

def split_ner_tags(ner_tags):
    span_labels = []
    entity_labels = []
    for tag in ner_tags:
        if tag == "O":
            span_labels.append("O")
            entity_labels.append("O")
        else:
            bio, entity = tag.split("-", 1)
            span_labels.append(bio)
            entity_labels.append(entity)
    return span_labels, entity_labels

def collate_fn(batch):
    batch_tokens = [item['tokens'] for item in batch]
    batch_ner_tags = [item['ner_tags'] for item in batch]

    batch_span_labels = []
    batch_entity_labels = []
    for ner_tags in batch_ner_tags:
        span_labels, entity_labels = split_ner_tags(ner_tags)
        batch_span_labels.append([span2id[label] for label in span_labels])
        batch_entity_labels.append([entity2id[label] for label in entity_labels])

    encodings = tokenizer(batch_tokens, is_split_into_words=True, return_tensors='pt', padding=True, truncation=True, max_length=128)
    input_ids = encodings['input_ids']
    attention_mask = encodings['attention_mask']

    max_len = input_ids.size(1)
    padded_span_labels = []
    padded_entity_labels = []
    for idx, (span_labels, entity_labels) in enumerate(zip(batch_span_labels, batch_entity_labels)):
        word_ids = encodings.word_ids(batch_index=idx)
        aligned_span = []
        aligned_entity = []
        prev_word_id = None
        for word_id in word_ids:
            if word_id is None:
                aligned_span.append(-100)
                aligned_entity.append(-100)
            elif word_id != prev_word_id:
                aligned_span.append(span_labels[word_id])
                aligned_entity.append(entity_labels[word_id])
            else:
                aligned_span.append(-100)
                aligned_entity.append(-100)
            prev_word_id = word_id
        padded_span_labels.append(aligned_span)
        padded_entity_labels.append(aligned_entity)

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'span_labels': torch.tensor(padded_span_labels),
        'entity_labels': torch.tensor(padded_entity_labels)
    }

# DataLoaders
batch_size = 8
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [6]:
# Model Definition
class E2DAEndogenous(nn.Module):
    def __init__(self, hidden_dim=1024, num_span_labels=len(span_label_set), num_entity_labels=len(entity_label_set)):
        super(E2DAEndogenous, self).__init__()
        self.encoder = encoder
        self.shared_extractor = nn.Linear(hidden_dim, hidden_dim)
        self.span_extractor = nn.Linear(hidden_dim, hidden_dim)
        self.type_extractor = nn.Linear(hidden_dim, hidden_dim)
        self.span_head = nn.Linear(hidden_dim, num_span_labels)
        self.type_head = nn.Linear(hidden_dim, num_entity_labels)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        H = outputs.last_hidden_state

        H_share = self.shared_extractor(H)
        H_st = H - H_share
        H_span = self.span_extractor(H_st)
        H_type = self.type_extractor(H_st)
        H_sp = H_span + H_share
        H_tp = H_type + H_share

        span_logits = self.span_head(self.dropout(H_sp))
        type_logits = self.type_head(self.dropout(H_tp))
        return span_logits, type_logits, H_span, H_type, H_share

# Endogenous Augmentation Loss
def compute_covariance_matrix(features, labels, num_classes):
    batch_size, seq_len, dim = features.size()
    device = features.device  # Get the device of the input features
    features_flat = features.view(-1, dim)
    labels_flat = labels.view(-1)
    cov_matrices = []
    for c in range(num_classes):
        class_features = features_flat[labels_flat == c]
        if class_features.numel() > 0 and class_features.size(0) > 1:
            cov = torch.cov(class_features.T)
            cov_matrices.append(cov + torch.eye(dim, device=device) * 1e-6)
        else:
            cov_matrices.append(torch.eye(dim, device=device) * 1e-6)
    return torch.stack(cov_matrices)

def endogenous_loss(logits, labels, features, head_weights, head_bias, lambda_):
    batch_size, seq_len, num_classes = logits.size()
    cov_matrices = compute_covariance_matrix(features, labels, num_classes)

    loss = 0
    valid_tokens = 0
    for i in range(batch_size):
        for j in range(seq_len):
            if labels[i, j] != -100:
                c_i = labels[i, j]
                h_i = features[i, j]
                log_sum_exp = 0
                for c_j in range(num_classes):
                    if c_j != c_i:
                        delta_w = head_weights[c_j] - head_weights[c_i]
                        delta_b = head_bias[c_j] - head_bias[c_i]
                        mean_term = delta_w @ h_i + delta_b
                        var_term = (lambda_ / 2) * delta_w @ cov_matrices[c_i] @ delta_w
                        log_sum_exp += torch.exp(mean_term + var_term)
                loss += torch.log(1 + log_sum_exp)
                valid_tokens += 1
    return loss / valid_tokens if valid_tokens > 0 else torch.tensor(0.0, device=logits.device)

# Orthogonality Loss
def orthogonality_loss(H_span, H_type, H_share):
    dot1 = (H_span * H_share).sum(dim=-1).pow(2).mean()
    dot2 = (H_type * H_share).sum(dim=-1).pow(2).mean()
    dot3 = (H_span * H_type).sum(dim=-1).pow(2).mean()
    return dot1 + dot2 + dot3

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = E2DAEndogenous().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
model.load_state_dict(torch.load("best_e2da_endogenous.pth"))
model.to(device)

E2DAEndogenous(
  (encoder): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(250002, 1024, padding_idx=1)
      (position_embeddings): Embedding(514, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): XLMRobertaEncoder(
      (layer): ModuleList(
        (0-23): 24 x XLMRobertaLayer(
          (attention): XLMRobertaAttention(
            (self): XLMRobertaSdpaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): XLMRobertaSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias

In [8]:
import torch
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score

alpha = 0.01
max_epochs = 435
patience = 8  # Stop training if no improvement for 'patience' epochs
best_train_loss = np.inf  # Track the best training loss
epochs_no_improve = 0  # Count epochs with no improvement

for epoch in range(max_epochs):
    model.train()
    total_loss = 0
    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{max_epochs}", leave=False)
    
    for batch in train_loader_tqdm:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        span_labels = batch['span_labels'].to(device)
        entity_labels = batch['entity_labels'].to(device)

        optimizer.zero_grad()
        span_logits, type_logits, H_span, H_type, H_share = model(input_ids, attention_mask)

        # Get classification head parameters
        span_weights = model.span_head.weight.detach()
        span_bias = model.span_head.bias.detach()
        type_weights = model.type_head.weight.detach()
        type_bias = model.type_head.bias.detach()

        # Dynamic lambda
        lambda_ = 1.5 * (epoch + 1) / max_epochs
        span_loss = endogenous_loss(span_logits, span_labels, H_span, span_weights, span_bias, lambda_)
        type_loss = endogenous_loss(type_logits, entity_labels, H_type, type_weights, type_bias, lambda_)
        ortho_loss = orthogonality_loss(H_span, H_type, H_share)

        loss = span_loss + type_loss + alpha * ortho_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        train_loader_tqdm.set_postfix(loss=loss.item())

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{max_epochs}, Training Loss: {avg_train_loss:.4f}")

    # Early Stopping Check (Based on Training Loss)
    if avg_train_loss < best_train_loss:
        best_train_loss = avg_train_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), "best_e2da_endogenous.pth")  # Save best model
        print("Model saved!")
    else:
        epochs_no_improve += 1
        # print(f"No improvement for {epochs_no_improve} epoch(s)")

    if epochs_no_improve >= patience:
        print("Early stopping triggered. Stopping training.")
        break


print("Training complete.")


                                                                  

Epoch 1/435, Training Loss: 0.5209
Model saved!


                                                                  

Epoch 2/435, Training Loss: 0.2763
Model saved!


                                                                  

Epoch 3/435, Training Loss: 0.2199
Model saved!


                                                                  

Epoch 4/435, Training Loss: 0.2147
Model saved!


                                                                  

Epoch 5/435, Training Loss: 0.2067
Model saved!


                                                                  

Epoch 6/435, Training Loss: 0.1912
Model saved!


                                                                  

Epoch 7/435, Training Loss: 0.1960


                                                                  

Epoch 8/435, Training Loss: 0.1938


                                                                  

Epoch 9/435, Training Loss: 0.2026


                                                                  

Epoch 10/435, Training Loss: 0.1931


                                                                  

Epoch 11/435, Training Loss: 0.1969


                                                                  

Epoch 12/435, Training Loss: 0.1907
Model saved!


                                                                  

Epoch 13/435, Training Loss: 0.1807
Model saved!


                                                                  

Epoch 14/435, Training Loss: 0.1804
Model saved!


                                                                  

Epoch 15/435, Training Loss: 0.1763
Model saved!


                                                                  

Epoch 16/435, Training Loss: 0.1642
Model saved!


                                                                  

Epoch 17/435, Training Loss: 0.1639
Model saved!


                                                                  

Epoch 18/435, Training Loss: 0.1647


                                                                  

Epoch 19/435, Training Loss: 0.1685


                                                                  

Epoch 20/435, Training Loss: 0.1721


                                                                  

Epoch 21/435, Training Loss: 0.1622
Model saved!


                                                                  

Epoch 22/435, Training Loss: 0.1585
Model saved!


                                                                  

Epoch 23/435, Training Loss: 0.1609


                                                                  

Epoch 24/435, Training Loss: 0.1622


                                                                  

Epoch 25/435, Training Loss: 0.1552
Model saved!


                                                                  

Epoch 26/435, Training Loss: 0.1535
Model saved!


                                                                  

Epoch 27/435, Training Loss: 0.1487
Model saved!


                                                                  

Epoch 28/435, Training Loss: 0.1507


                                                                  

Epoch 29/435, Training Loss: 0.1508


                                                                  

Epoch 30/435, Training Loss: 0.1531


                                                                  

Epoch 31/435, Training Loss: 0.1476
Model saved!


                                                                  

Epoch 32/435, Training Loss: 0.1428
Model saved!


                                                                  

Epoch 33/435, Training Loss: 0.1435


                                                                  

Epoch 34/435, Training Loss: 0.1411
Model saved!


                                                                  

Epoch 35/435, Training Loss: 0.1450


                                                                  

Epoch 36/435, Training Loss: 0.1483


                                                                  

Epoch 37/435, Training Loss: 0.1482


                                                                  

Epoch 38/435, Training Loss: 0.1357
Model saved!


                                                                  

Epoch 39/435, Training Loss: 0.1343
Model saved!


                                                                  

Epoch 40/435, Training Loss: 0.1373


                                                                  

Epoch 41/435, Training Loss: 0.1305
Model saved!


                                                                  

Epoch 42/435, Training Loss: 0.1327


                                                                  

Epoch 43/435, Training Loss: 0.1459


                                                                  

Epoch 44/435, Training Loss: 0.1376


                                                                  

Epoch 45/435, Training Loss: 0.1300
Model saved!


                                                                  

Epoch 46/435, Training Loss: 0.1220
Model saved!


                                                                  

Epoch 47/435, Training Loss: 0.1317


                                                                  

Epoch 48/435, Training Loss: 0.1363


                                                                  

Epoch 49/435, Training Loss: 0.1270


                                                                  

Epoch 50/435, Training Loss: 0.1186
Model saved!


                                                                  

Epoch 51/435, Training Loss: 0.1143
Model saved!


                                                                  

Epoch 52/435, Training Loss: 0.1170


                                                                  

Epoch 53/435, Training Loss: 0.1380


                                                                  

Epoch 54/435, Training Loss: 0.1572


                                                                  

Epoch 55/435, Training Loss: 0.1423


                                                                  

Epoch 56/435, Training Loss: 0.1242


                                                                  

Epoch 57/435, Training Loss: 0.1128
Model saved!


                                                                  

Epoch 58/435, Training Loss: 0.1087
Model saved!


                                                                  

Epoch 59/435, Training Loss: 0.1084
Model saved!


                                                                  

Epoch 60/435, Training Loss: 0.1227


                                                                  

Epoch 61/435, Training Loss: 0.1183


                                                                  

Epoch 62/435, Training Loss: 0.1168


                                                                  

Epoch 63/435, Training Loss: 0.1036
Model saved!


                                                                  

Epoch 64/435, Training Loss: 0.1046


                                                                  

Epoch 65/435, Training Loss: 0.1044


                                                                  

Epoch 66/435, Training Loss: 0.1053


                                                                  

Epoch 67/435, Training Loss: 0.1018
Model saved!


                                                                  

Epoch 68/435, Training Loss: 0.1124


                                                                  

Epoch 69/435, Training Loss: 0.1065


                                                                  

Epoch 70/435, Training Loss: 0.1084


                                                                  

Epoch 71/435, Training Loss: 0.1032


                                                                  

Epoch 72/435, Training Loss: 0.1096


                                                                  

Epoch 73/435, Training Loss: 0.1002
Model saved!


                                                                  

Epoch 74/435, Training Loss: 0.0945
Model saved!


                                                                  

Epoch 75/435, Training Loss: 0.1014


                                                                  

Epoch 76/435, Training Loss: 0.1007


                                                                  

Epoch 77/435, Training Loss: 0.1007


                                                                  

Epoch 78/435, Training Loss: 0.1043


                                                                  

Epoch 79/435, Training Loss: 0.1049


                                                                  

Epoch 80/435, Training Loss: 0.0988


                                                                  

Epoch 81/435, Training Loss: 0.1090


                                                                  

Epoch 82/435, Training Loss: 0.1020
Early stopping triggered. Stopping training.
Training complete.




150 epochs already runned, model was loaded again for training

In [9]:

torch.save(model.state_dict(), "e2da_endogenous2.pth")
print("Model saved successfully.")
# model.load_state_dict(torch.load("e2da_endogenous.pth"))
# model.to(device)
# model.eval()  # Set the model to evaluation mode if using for inference


Model saved successfully.


In [10]:
from tqdm import tqdm
from sklearn.metrics import f1_score

# Test Evaluation with tqdm
model.eval()
test_preds, test_labels = [],[]

test_loader_tqdm = tqdm(test_loader, desc="Testing", leave=False)

with torch.no_grad():
    for batch in test_loader_tqdm:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        span_labels = batch['span_labels'].to(device)
        entity_labels = batch['entity_labels'].to(device)

        span_logits, type_logits, _, _, _ = model(input_ids, attention_mask)
        span_preds = torch.argmax(span_logits, dim=-1)
        type_preds = torch.argmax(type_logits, dim=-1)

        for i in range(span_preds.size(0)):
            for j in range(span_preds.size(1)):
                if span_labels[i, j] != -100:  # Ignore padding tokens
                    # Construct the predicted and true labels
                    pred_label = f"{span_label_set[span_preds[i, j]]}-{entity_label_set[type_preds[i, j]]}" if span_preds[i, j] != 2 else "O"
                    true_label = f"{span_label_set[span_labels[i, j]]}-{entity_label_set[entity_labels[i, j]]}" if span_labels[i, j] != 2 else "O"

                    # Only append non-"O" labels to the lists
                    if pred_label != "O" and true_label != "O":
                        test_preds.append(pred_label)
                        test_labels.append(true_label)

# Compute Micro-F1 Score for non-"O" predictions
micro_f1 = f1_score(test_labels, test_preds, average='micro')
print(f"Test Micro-F1 (non-'O'): {micro_f1:.4f}")


                                                                  

Test Micro-F1 (non-'O'): 0.5113
