## Cấu hình

In [4]:
import json
import os
import pickle
import logging
import sys
import random
import argparse
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel
from huggingface_hub import login

In [5]:
class Config:
    # Dataset configs
    DATASETS = ['WEBNLG', 'NYT']
    DATA_PATH = "./data_care"

    # Model configs
    HIDDEN_SIZE = 768
    SHARE_HIDDEN_SIZE = 128
    DIST_EMB_SIZE = 20
    CO_ATTENTION_LAYERS = 3

    # Training configs
    EPOCHS = 10
    BATCH_SIZE = 8
    EVAL_BATCH_SIZE = 8
    # LR = 0.0001 # webnlg
    LR = 0.00001
    WEIGHT_DECAY = 0
    DROPOUT = 0.1
    CLIP = 0.5
    SEED = 0

    # Other configs
    EVAL_METRIC = "micro"  # or "macro"
    MAX_SEQ_LEN = 128

    # HF Token (replace with your token)
    HF_TOKEN = ""

config = Config()

## helper

In [6]:
def setup_logging(output_dir):
    """Setup logging configuration"""
    os.makedirs(output_dir, exist_ok=True)

    logger = logging.getLogger("CARELogger")
    logger.setLevel(logging.INFO)
    logger.handlers.clear()

    # File handler
    log_file = os.path.join(output_dir, "train_output.log")
    file_handler = logging.FileHandler(log_file, mode='w')
    file_handler.setLevel(logging.INFO)

    # Stream handler
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setLevel(logging.INFO)

    # Formatter
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)

    return logger

def set_seed(seed):
    """Set random seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def load_json(path, filename):
    """Load JSON file"""
    filepath = os.path.join(path, filename)
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)


## Data Processing

#### Tiền xử lý dữ liệu

1. Đọc dữ liệu thô từ file JSON
2. Tiền xử lý từng mẫu dữ liệu:
* Tách câu thành các từ.
* Chuyển các thực thể và quan hệ thành nhãn chỉ số.
* Chuyển đổi nhãn từ cấp độ từ sang cấp độ token BERT.
* Tạo ma trận khoảng cách (distance matrix) cho từng cặp token.
3. Lưu dữ liệu đã xử lý vào các đối tượng Dataset.
4. Tạo DataLoader để phục vụ huấn luyện và đánh giá.

In [21]:
class Collator:
    def __init__(self, ner2idx, rel2idx):
        self.ner2idx = ner2idx
        self.rel2idx = rel2idx

    def __call__(self, batch):
        texts = [item[0] for item in batch]
        ner_labels = [item[1] for item in batch]
        rc_labels = [item[2] for item in batch]
        bert_lens = [item[3] for item in batch]
        dists = [item[4] for item in batch]

        max_len = max(bert_lens)
        batch_size = len(texts)

        # Pad distance matrices
        padded_dists = []
        for dist in dists:
            current_len = dist.size(0)
            if current_len < max_len:
                pad_size = max_len - current_len
                padded_dist = F.pad(dist, (0, pad_size, 0, pad_size), value=19)
            else:
                padded_dist = dist[:max_len, :max_len]
            padded_dists.append(padded_dist)

        dist_tensor = torch.stack(padded_dists)

        # Generate label tensors
        ner_tensors = []
        re_tensors = []

        for ner_list, rc_list in zip(ner_labels, rc_labels):
            # NER labels
            ner_tensor = torch.zeros(max_len, max_len, len(self.ner2idx))
            for i in range(0, len(ner_list), 3):
                if i + 2 < len(ner_list):
                    head, tail, ner_type = ner_list[i], ner_list[i+1], ner_list[i+2]
                    if head < max_len and tail < max_len and ner_type in self.ner2idx:
                        ner_tensor[head, tail, self.ner2idx[ner_type]] = 1
            ner_tensors.append(ner_tensor)

            # Relation labels
            re_tensor = torch.zeros(max_len, max_len, len(self.rel2idx))
            for i in range(0, len(rc_list), 3):
                if i + 2 < len(rc_list):
                    subj, obj, rel_type = rc_list[i], rc_list[i+1], rc_list[i+2]
                    if subj < max_len and obj < max_len and rel_type in self.rel2idx:
                        re_tensor[subj, obj, self.rel2idx[rel_type]] = 1
            re_tensors.append(re_tensor)

        ner_labels_tensor = torch.stack(ner_tensors).permute(1, 2, 0, 3)
        re_labels_tensor = torch.stack(re_tensors).permute(1, 2, 0, 3)

        # Create mask
        mask = torch.zeros(max_len, batch_size)
        for i, length in enumerate(bert_lens):
            mask[:length, i] = 1

        return texts, ner_labels_tensor, re_labels_tensor, mask, dist_tensor


#### class Collator:
* Collator dùng để gom các mẫu riêng lẻ thành một batch cho DataLoader.
* Tác vụ chính:
    * Pad các ma trận khoảng cách về cùng kích thước (max_len).
    * Tạo tensor nhãn NER và quan hệ cho từng batch (dạng one-hot).
    * Tạo mask để đánh dấu các vị trí hợp lệ (không phải padding).
* Đầu ra: texts, tensor nhãn NER, tensor nhãn quan hệ, mask, tensor khoảng cách

In [22]:
# Distance mapping
class CAREDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def map_word_to_bert(words, tokenizer):
    """Map original words to BERT token positions"""
    mapping = {}
    current_idx = 0
    for word_idx, word in enumerate(words):
        bert_tokens = tokenizer.tokenize(word)
        word_len = len(bert_tokens)
        mapping[word_idx] = [current_idx, current_idx + word_len - 1]
        current_idx += word_len
    return mapping

def transform_labels(labels, word_to_bert, label_type='ner'):
    """Transform word-level labels to BERT token-level labels"""
    new_labels = []
    for i in range(0, len(labels), 3):
        if label_type == 'ner':
            # +1 for [CLS] token
            start = word_to_bert[labels[i]][0] + 1
            end = word_to_bert[labels[i + 1]][0] + 1
        else:  # relation
            start = word_to_bert[labels[i]][0] + 1
            end = word_to_bert[labels[i + 1]][0] + 1
        new_labels.extend([start, end, labels[i + 2]])
    return new_labels


#### class CAREDataset
* CAREDataset kế thừa từ torch.utils.data.Dataset.

* Lưu trữ danh sách các mẫu dữ liệu đã được tiền xử lý.

* Cung cấp hai phương thức chính:
    * __len__: trả về số lượng mẫu.
    * __getitem__: trả về một mẫu dữ liệu tại chỉ số idx (bao gồm: text, nhãn NER, nhãn quan hệ, độ dài, ma trận khoảng cách).


In [23]:
dis2idx = np.zeros(4000, dtype='int64')
dis2idx[1] = 1
dis2idx[2:] = 2
dis2idx[4:] = 3
dis2idx[8:] = 4
dis2idx[16:] = 5
dis2idx[32:] = 6
dis2idx[64:] = 7
dis2idx[128:] = 8
dis2idx[256:] = 9

def preprocess_data(data, tokenizer, dataset_name):
    """Preprocess data for both NYT and WEBNLG datasets"""
    processed = []

    for item in tqdm(data, desc=f"Processing {dataset_name}"):
        text = item['text'].split()
        ner_labels = []
        rc_labels = []

        for triple in item['triple_list']:
            subj, rel, obj = triple
            try:
                subj_idx = text.index(subj)
                obj_idx = text.index(obj)

                # Add entities if not already present
                if subj_idx not in [ner_labels[i] for i in range(0, len(ner_labels), 3)]:
                    ner_labels.extend([subj_idx, subj_idx, "None"])
                if obj_idx not in [ner_labels[i] for i in range(0, len(ner_labels), 3)]:
                    ner_labels.extend([obj_idx, obj_idx, "None"])

                rc_labels.extend([subj_idx, obj_idx, rel])
            except ValueError:
                continue  # Skip if entity not found in text

        # Truncate if too long
        if len(text) > config.MAX_SEQ_LEN:
            text = text[:config.MAX_SEQ_LEN]
            # Filter labels to keep only those within truncated text
            ner_labels = [l for i, l in enumerate(ner_labels)
                         if i % 3 == 2 or l < config.MAX_SEQ_LEN]
            rc_labels = [l for i, l in enumerate(rc_labels)
                        if i % 3 == 2 or l < config.MAX_SEQ_LEN]

        # Convert to BERT format
        sent_str = ' '.join(text)
        bert_tokens = tokenizer.tokenize(sent_str)
        bert_len = len(bert_tokens) + 2  # +2 for [CLS] and [SEP]

        word_to_bert = map_word_to_bert(text, tokenizer)
        ner_labels = transform_labels(ner_labels, word_to_bert, 'ner')
        rc_labels = transform_labels(rc_labels, word_to_bert, 'relation')

        # Create distance matrix
        dist_inputs = torch.zeros((bert_len, bert_len), dtype=torch.long)
        for k in range(bert_len):
            dist_inputs[k, :] += k
            dist_inputs[:, k] -= k

        for i in range(bert_len):
            for j in range(bert_len):
                if dist_inputs[i, j] < 0:
                    dist_inputs[i, j] = int(dis2idx[-dist_inputs[i, j]]) + 9
                else:
                    dist_inputs[i, j] = int(dis2idx[dist_inputs[i, j]])
        dist_inputs[dist_inputs == 0] = 19

        processed.append((text, ner_labels, rc_labels, bert_len, dist_inputs))

    return processed

#### preprocess_data
Nhận vào dữ liệu thô, tokenizer và tên dataset.

Với mỗi mẫu:

* Tách từ.
* Tạo nhãn thực thể (ner_labels) và quan hệ (rc_labels) ở cấp độ từ.
* Chuyển đổi nhãn sang cấp độ token BERT.
* Tạo ma trận khoảng cách giữa các token.

Trả về danh sách các mẫu đã xử lý, mỗi mẫu là một list gồm: text, nhãn NER, nhãn quan hệ, độ dài, ma trận khoảng cách.

In [24]:
def create_dataloaders(dataset_name):
    """Create dataloaders for specified dataset"""
    data_path = os.path.join(config.DATA_PATH, dataset_name)

    # Load label mappings
    with open(os.path.join(data_path, "ner2idx.json"), "r") as f:
        ner2idx = json.load(f)
    with open(os.path.join(data_path, "rel2idx.json"), "r") as f:
        rel2idx = json.load(f)

    # Load datasets
    train_data = load_json(data_path, 'train_triples.json')
    test_data = load_json(data_path, 'test_triples.json')
    dev_data = load_json(data_path, 'dev_triples.json')

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    # Process datasets
    train_processed = preprocess_data(train_data, tokenizer, dataset_name)
    test_processed = preprocess_data(test_data, tokenizer, dataset_name)
    dev_processed = preprocess_data(dev_data, tokenizer, dataset_name)

    # Create datasets
    train_dataset = CAREDataset(train_processed)
    test_dataset = CAREDataset(test_processed)
    dev_dataset = CAREDataset(dev_processed)

    # Create collator
    collate_fn = Collator(ner2idx, rel2idx)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE,
                             shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=config.EVAL_BATCH_SIZE,
                            shuffle=False, collate_fn=collate_fn)
    dev_loader = DataLoader(dev_dataset, batch_size=config.EVAL_BATCH_SIZE,
                           shuffle=False, collate_fn=collate_fn)

    return train_loader, test_loader, dev_loader, ner2idx, rel2idx


In [36]:
data_path = './data_care/WEBNLG'
train_data_weblg = json.load(open(os.path.join(data_path, 'train_triples.json')))
test_data_weblg = json.load(open(os.path.join(data_path, 'test_triples.json')))
dev_data_weblg = json.load(open(os.path.join(data_path, 'dev_triples.json')))

print("WEBNLG Datasets:\n")
print(len(train_data_weblg))
print(len(test_data_weblg))
print(len(dev_data_weblg))

print()

data_path = './data_care/NYT'

train_data_nyt = json.load(open(os.path.join(data_path, 'train_triples.json')))
test_data_nyt = json.load(open(os.path.join(data_path, 'test_triples.json')))
dev_data_nyt = json.load(open(os.path.join(data_path, 'dev_triples.json')))

print("NYT Datasets:\n")
print(len(train_data_nyt))
print(len(test_data_nyt))
print(len(dev_data_nyt))

WEBNLG Datasets:

5019
703
499

NYT Datasets:

56195
5000
4999


In [37]:
print(train_data_weblg[0])
print(train_data_nyt[0])

{'text': 'Alan Bean , who was part of Apollo 12 , was born in Wheeler , Texas on March 15th , 1932 and is now retired .', 'triple_list': [['Bean', 'was a crew member of', '12'], ['Bean', 'birthPlace', 'Texas']]}
{'text': 'Massachusetts ASTON MAGNA Great Barrington ; also at Bard College , Annandale-on-Hudson , N.Y. , July 1-Aug .', 'triple_list': [['Annandale-on-Hudson', '/location/location/contains', 'College']]}


In [45]:
# 2. Tách từ và ánh xạ thực thể
words = train_data_weblg[0]["text"].split()
print("Bước 2 - Tách từ:", words)

Bước 2 - Tách từ: ['Alan', 'Bean', ',', 'who', 'was', 'part', 'of', 'Apollo', '12', ',', 'was', 'born', 'in', 'Wheeler', ',', 'Texas', 'on', 'March', '15th', ',', '1932', 'and', 'is', 'now', 'retired', '.']


In [38]:
words = train_data_weblg[0]['text'].split()

ner_labels = []
rc_labels = []
for triple in train_data_weblg[0]["triple_list"]:
    subj, rel, obj = triple
    try:
        subj_idx = words.index(subj.split()[0])  # Lấy vị trí từ đầu tiên
        obj_idx = words.index(obj)
        ner_labels.extend([subj_idx, subj_idx+1, "None"])  # Barack Obama: 0-1
        ner_labels.extend([obj_idx, obj_idx, "None"])      # Hawaii: 5
        rc_labels.extend([subj_idx, obj_idx, rel])
    except ValueError:
        continue
print("Bước 3 - NER labels:", ner_labels)
print("Bước 3 - RE labels:", rc_labels)

Bước 3 - NER labels: [1, 2, 'None', 8, 8, 'None', 1, 2, 'None', 15, 15, 'None']
Bước 3 - RE labels: [1, 8, 'was a crew member of', 1, 15, 'birthPlace']


In [39]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
bert_tokens = tokenizer.tokenize(train_data_weblg[0]['text'])
print("Bước 4 - BERT tokens:", bert_tokens)

Bước 4 - BERT tokens: ['Alan', 'Bean', ',', 'who', 'was', 'part', 'of', 'Apollo', '12', ',', 'was', 'born', 'in', 'Wheeler', ',', 'Texas', 'on', 'March', '15th', ',', '1932', 'and', 'is', 'now', 'retired', '.']


In [40]:
word_to_bert = map_word_to_bert(words, tokenizer)
print("Bước 5 - Ánh xạ từ->BERT:", word_to_bert)

Bước 5 - Ánh xạ từ->BERT: {0: [0, 0], 1: [1, 1], 2: [2, 2], 3: [3, 3], 4: [4, 4], 5: [5, 5], 6: [6, 6], 7: [7, 7], 8: [8, 8], 9: [9, 9], 10: [10, 10], 11: [11, 11], 12: [12, 12], 13: [13, 13], 14: [14, 14], 15: [15, 15], 16: [16, 16], 17: [17, 17], 18: [18, 18], 19: [19, 19], 20: [20, 20], 21: [21, 21], 22: [22, 22], 23: [23, 23], 24: [24, 24], 25: [25, 25]}


In [41]:

ner_labels_bert = transform_labels(ner_labels, word_to_bert, 'ner')
rc_labels_bert = transform_labels(rc_labels, word_to_bert, 'relation')
print("Bước 6 - NER labels (BERT):", ner_labels_bert)
print("Bước 6 - RE labels (BERT):", rc_labels_bert)


Bước 6 - NER labels (BERT): [1, 2, 'None', 8, 8, 'None', 1, 2, 'None', 15, 15, 'None']
Bước 6 - RE labels (BERT): [1, 8, 'was a crew member of', 1, 15, 'birthPlace']


In [42]:
def build_dis2idx():
    dis2idx = np.zeros(4000, dtype='int64')
    dis2idx[1] = 1
    dis2idx[2:] = 2
    dis2idx[4:] = 3
    dis2idx[8:] = 4
    dis2idx[16:] = 5
    dis2idx[32:] = 6
    dis2idx[64:] = 7
    dis2idx[128:] = 8
    dis2idx[256:] = 9
    return dis2idx
dis2idx = build_dis2idx()
bert_len = len(bert_tokens) + 2  # +2 cho [CLS], [SEP]
dist_inputs = torch.zeros((bert_len, bert_len), dtype=torch.long)
for k in range(bert_len):
    dist_inputs[k, :] += k
    dist_inputs[:, k] -= k
for i in range(bert_len):
    for j in range(bert_len):
        if dist_inputs[i, j] < 0:
            dist_inputs[i, j] = int(dis2idx[-dist_inputs[i, j]]) + 9
        else:
            dist_inputs[i, j] = int(dis2idx[dist_inputs[i, j]])
dist_inputs[dist_inputs == 0] = 19
print("Bước 7 - Ma trận khoảng cách:")
print(dist_inputs)

Bước 7 - Ma trận khoảng cách:
tensor([[19, 10, 11, 11, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14,
         14, 14, 14, 14, 14, 14, 14, 14, 14, 14],
        [ 1, 19, 10, 11, 11, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 14,
         14, 14, 14, 14, 14, 14, 14, 14, 14, 14],
        [ 2,  1, 19, 10, 11, 11, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13,
         14, 14, 14, 14, 14, 14, 14, 14, 14, 14],
        [ 2,  2,  1, 19, 10, 11, 11, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13,
         13, 14, 14, 14, 14, 14, 14, 14, 14, 14],
        [ 3,  2,  2,  1, 19, 10, 11, 11, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13,
         13, 13, 14, 14, 14, 14, 14, 14, 14, 14],
        [ 3,  3,  2,  2,  1, 19, 10, 11, 11, 12, 12, 12, 12, 13, 13, 13, 13, 13,
         13, 13, 13, 14, 14, 14, 14, 14, 14, 14],
        [ 3,  3,  3,  2,  2,  1, 19, 10, 11, 11, 12, 12, 12, 12, 13, 13, 13, 13,
         13, 13, 13, 13, 14, 14, 14, 14, 14, 14],
        [ 3,  3,  3,  3,  2,  2,  1, 19, 10, 11, 11, 

In [43]:
# 8. Đóng gói thành mẫu dữ liệu
sample = [words, ner_labels_bert, rc_labels_bert, bert_len, dist_inputs]
print("\nBước 8 - Mẫu dữ liệu hoàn chỉnh:")
print("words:", sample[0])
print("ner_labels_bert:", sample[1])
print("rc_labels_bert:", sample[2])
print("bert_len:", sample[3])
print("dist_inputs shape:", sample[4].shape)


Bước 8 - Mẫu dữ liệu hoàn chỉnh:
words: ['Alan', 'Bean', ',', 'who', 'was', 'part', 'of', 'Apollo', '12', ',', 'was', 'born', 'in', 'Wheeler', ',', 'Texas', 'on', 'March', '15th', ',', '1932', 'and', 'is', 'now', 'retired', '.']
ner_labels_bert: [1, 2, 'None', 8, 8, 'None', 1, 2, 'None', 15, 15, 'None']
rc_labels_bert: [1, 8, 'was a crew member of', 1, 15, 'birthPlace']
bert_len: 28
dist_inputs shape: torch.Size([28, 28])


#### preprocess_data
Nhận vào tên dataset (NYT hoặc WEBNLG).

Đọc file nhãn, dữ liệu train/test/dev.

Khởi tạo tokenizer.

Tiền xử lý dữ liệu với preprocess_data.

Tạo các CAREDataset cho từng tập.

Tạo Collator để gom batch.

Tạo DataLoader cho train, test, dev.

Trả về các DataLoader và mapping nhãn.

In [48]:
train_loader_weblg, test_loader_weblg, dev_loader_weblg, ner2idx_weblg, rel2idx_weblg = create_dataloaders('WEBNLG')

print()
print("Train loader size:", len(train_loader_weblg))
print("Test loader size:", len(test_loader_weblg))
print("Dev loader size:", len(dev_loader_weblg))

Processing WEBNLG: 100%|██████████| 5019/5019 [02:17<00:00, 36.44it/s] 
Processing WEBNLG: 100%|██████████| 703/703 [00:24<00:00, 28.90it/s] 
Processing WEBNLG: 100%|██████████| 499/499 [00:13<00:00, 36.32it/s] 


Train loader size: 628
Test loader size: 88
Dev loader size: 63





In [None]:
print("NER2IDX:", ner2idx_weblg)
print("REL2IDX:", rel2idx_weblg)
print("Train loader:", train_loader_weblg)
print("Test loader:", test_loader_weblg)
print("Dev loader:", dev_loader_weblg)

NER2IDX: {'None': 0}
REL2IDX: {'firstPublicationYear': 0, 'EISSN_number': 1, 'has to its northeast': 2, 'aircraftFighter': 3, 'elevationAboveTheSeaLevel': 4, 'influencedBy': 5, 'ethnicGroup': 6, 'chief': 7, 'fossil': 8, 'served as Chief of the Astronaut Office in': 9, 'class': 10, 'aircraftHelicopter': 11, 'backup pilot': 12, 'league': 13, '4th_runway_LengthFeet': 14, 'battles': 15, 'foundationPlace': 16, 'occupation': 17, 'leaderName': 18, 'headquarter': 19, 'jurisdiction': 20, 'numberOfMembers': 21, 'champions': 22, 'has to its southwest': 23, 'leaderTitle': 24, 'precededBy': 25, 'has to its northwest': 26, 'buildingType': 27, 'river': 28, 'architect': 29, 'neighboringMunicipality': 30, 'followedBy': 31, 'nearestCity': 32, 'editor': 33, 'officialLanguage': 34, 'chairman': 35, 'has to its southeast': 36, 'academicDiscipline': 37, 'tenant': 38, 'largestCity': 39, 'LCCN_number': 40, 'numberOfRooms': 41, 'was a crew member of': 42, 'administrativeArrondissement': 43, 'architecturalStyle'

In [None]:
train_loader_nyt, test_loader_nyt, dev_loader_nyt, ner2idx_nyt, rel2idx_nyt = create_dataloaders('NYT')

print()
print("Train loader size:", len(train_loader_nyt))
print("Test loader size:", len(test_loader_nyt))
print("Dev loader size:", len(dev_loader_nyt))

Processing NYT: 100%|██████████| 56195/56195 [1:10:51<00:00, 13.22it/s]
Processing NYT: 100%|██████████| 5000/5000 [05:53<00:00, 14.14it/s]
Processing NYT: 100%|██████████| 4999/4999 [06:13<00:00, 13.39it/s]


Train loader size: 7025
Test loader size: 625
Dev loader size: 625





In [None]:
print("NER2IDX:", ner2idx_nyt)
print("REL2IDX:", rel2idx_nyt)
print("Train loader:", train_loader_nyt)
print("Test loader:", test_loader_nyt)
print("Dev loader:", dev_loader_nyt)

NER2IDX: {'None': 0}
REL2IDX: {'/location/country/administrative_divisions': 0, '/sports/sports_team_location/teams': 1, '/people/ethnicity/geographic_distribution': 2, '/people/person/nationality': 3, '/people/person/ethnicity': 4, '/business/company/major_shareholders': 5, '/people/person/place_lived': 6, '/location/country/capital': 7, '/people/person/children': 8, '/business/company/industry': 9, '/business/company_shareholder/major_shareholder_of': 10, '/business/person/company': 11, '/people/ethnicity/people': 12, '/people/person/religion': 13, '/location/neighborhood/neighborhood_of': 14, '/people/deceased_person/place_of_death': 15, '/people/person/place_of_birth': 16, '/business/company/place_founded': 17, '/location/administrative_division/country': 18, '/location/location/contains': 19, '/business/company/advisors': 20, '/people/person/profession': 21, '/business/company/founders': 22, '/sports/sports_team/location': 23}
Train loader: <torch.utils.data.dataloader.DataLoader 

## Model

In [9]:
class NERUnit(nn.Module):
    def __init__(self, ner2idx, hidden_size, share_hidden_size, dropout):
        super().__init__()
        self.ner2idx = ner2idx
        self.transform = nn.Linear(hidden_size * 2 + share_hidden_size, hidden_size)
        self.classifier = nn.Linear(hidden_size, len(ner2idx))
        self.norm = nn.LayerNorm(hidden_size)
        self.activation = nn.ELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, h_ner, h_share, mask):
        length, batch_size, _ = h_ner.size()
        device = h_ner.device

        # Create entity span representations
        start_rep = h_ner.unsqueeze(1).repeat(1, length, 1, 1)
        end_rep = h_ner.unsqueeze(0).repeat(length, 1, 1, 1)

        # Combine representations
        combined = torch.cat([start_rep, end_rep, h_share], dim=-1)
        transformed = self.norm(self.transform(combined))
        activated = self.activation(self.dropout(transformed))
        output = torch.sigmoid(self.classifier(activated))

        # Apply masks
        diagonal_mask = torch.triu(torch.ones(batch_size, length, length)).to(device)
        diagonal_mask = diagonal_mask.permute(1, 2, 0)

        mask_s = mask.unsqueeze(1).repeat(1, length, 1)
        mask_e = mask.unsqueeze(0).repeat(length, 1, 1)
        combined_mask = diagonal_mask * mask_s * mask_e
        combined_mask = combined_mask.unsqueeze(-1).repeat(1, 1, 1, len(self.ner2idx))

        return output * combined_mask

class RelationUnit(nn.Module):
    def __init__(self, rel2idx, hidden_size, share_hidden_size, dropout):
        super().__init__()
        self.rel2idx = rel2idx
        self.transform = nn.Linear(hidden_size * 2 + share_hidden_size, hidden_size)
        self.classifier = nn.Linear(hidden_size, len(rel2idx))
        self.norm = nn.LayerNorm(hidden_size)
        self.activation = nn.ELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, h_re, h_share, mask):
        length, batch_size, _ = h_re.size()

        # Create relation representations
        subj_rep = h_re.unsqueeze(1).repeat(1, length, 1, 1)
        obj_rep = h_re.unsqueeze(0).repeat(length, 1, 1, 1)

        # Combine representations
        combined = torch.cat([subj_rep, obj_rep, h_share], dim=-1)
        transformed = self.norm(self.transform(combined))
        activated = self.activation(self.dropout(transformed))
        output = torch.sigmoid(self.classifier(activated))

        # Apply masks
        mask_expanded = mask.unsqueeze(-1).repeat(1, 1, len(self.rel2idx))
        mask_subj = mask_expanded.unsqueeze(1).repeat(1, length, 1, 1)
        mask_obj = mask_expanded.unsqueeze(0).repeat(length, 1, 1, 1)
        combined_mask = mask_subj * mask_obj

        return output * combined_mask

class ConvAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pre_channels, channels, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        input_channels = hid_dim * 2 + pre_channels

        self.linear1 = nn.Linear(hid_dim, hid_dim, bias=False)
        self.linear2 = nn.Linear(hid_dim, hid_dim, bias=False)

        self.conv = nn.Sequential(
            nn.Dropout2d(dropout),
            nn.Conv2d(input_channels, channels, kernel_size=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
        )
        self.score_layer = nn.Conv2d(channels, n_heads, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x, y, pre_conv=None, mask=None):
        ori_x, ori_y = x, y
        B, M, _ = x.size()
        B, N, _ = y.size()

        # Create feature map
        fea_map = torch.cat([
            x.unsqueeze(2).repeat_interleave(N, 2),
            y.unsqueeze(1).repeat_interleave(M, 1)
        ], -1).permute(0, 3, 1, 2).contiguous()

        if pre_conv is not None:
            fea_map = torch.cat([fea_map, pre_conv], 1)

        fea_map = self.conv(fea_map)
        scores = self.activation(self.score_layer(fea_map))

        if mask is not None:
            mask = mask.expand_as(scores)
            scores = scores.masked_fill(mask.eq(0), -9e10)

        # Apply attention
        x = self.linear1(self.dropout(x))
        y = self.linear2(self.dropout(y))

        out_x = torch.matmul(F.softmax(scores, -1),
                           y.view(B, N, self.n_heads, -1).transpose(1, 2))
        out_x = out_x.transpose(1, 2).contiguous().view(B, M, -1)

        out_y = torch.matmul(F.softmax(scores.transpose(2, 3), -1),
                           x.view(B, M, self.n_heads, -1).transpose(1, 2))
        out_y = out_y.transpose(1, 2).contiguous().view(B, N, -1)

        # Residual connections
        out_x = self.activation(out_x + x) + ori_x
        out_y = self.activation(out_y + y) + ori_y

        return out_x, out_y, fea_map

class ConvAttention(nn.Module):
    def __init__(self, hid_dim, n_heads, pre_channels, channels, layers, dropout):
        super().__init__()
        self.layers = nn.ModuleList([
            ConvAttentionLayer(hid_dim, n_heads,
                             pre_channels if i == 0 else channels,
                             channels, dropout)
            for i in range(layers)
        ])

    def forward(self, x, y, fea_map=None, mask=None):
        for layer in self.layers:
            x, y, fea_map = layer(x, y, fea_map, mask)
        return x, y, fea_map.permute(0, 2, 3, 1).contiguous()


In [10]:
class CARE(nn.Module):
    def __init__(self, ner2idx, rel2idx):
        super().__init__()

        self.ner_unit = NERUnit(ner2idx, config.HIDDEN_SIZE,
                               config.SHARE_HIDDEN_SIZE, config.DROPOUT)
        self.relation_unit = RelationUnit(rel2idx, config.HIDDEN_SIZE,
                                        config.SHARE_HIDDEN_SIZE, config.DROPOUT)

        self.dist_emb = nn.Embedding(20, config.DIST_EMB_SIZE)
        self.dropout = nn.Dropout(config.DROPOUT)

        # BERT components
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        self.bert = AutoModel.from_pretrained("bert-base-cased")

        # Co-attention mechanism
        self.conv_attention = ConvAttention(
            hid_dim=config.HIDDEN_SIZE,
            n_heads=1,
            pre_channels=config.DIST_EMB_SIZE,
            channels=config.SHARE_HIDDEN_SIZE,
            layers=config.CO_ATTENTION_LAYERS,
            dropout=config.DROPOUT
        )

    def forward(self, texts, mask, dist):
        # Encode with BERT
        encoded = self.tokenizer(texts, return_tensors="pt",
                                padding='longest', is_split_into_words=True)
        encoded = {k: v.to(mask.device) for k, v in encoded.items()}

        x = self.bert(**encoded)[0]
        if self.training:
            x = self.dropout(x)

        length = x.size(1)

        # Distance embeddings
        dist_emb = self.dist_emb(dist).permute(0, 3, 1, 2)

        # Create padding masks
        padding_mask = mask.unsqueeze(-1)
        mask1 = padding_mask.unsqueeze(1).repeat(1, length, 1, 1)
        mask2 = padding_mask.unsqueeze(0).repeat(length, 1, 1, 1)
        padding_mask = (mask1 * mask2).permute(2, 3, 0, 1)

        # Co-attention
        h_ner, h_re, h_share = self.conv_attention(x, x, dist_emb, padding_mask)

        # Rearrange for NER and RE units
        h_ner = h_ner.permute(1, 0, 2)
        h_re = h_re.permute(1, 0, 2)
        h_share = h_share.permute(1, 2, 0, 3)

        # Get predictions
        ner_pred = self.ner_unit(h_ner, h_share, mask)
        re_pred = self.relation_unit(h_re, h_share, mask)

        return ner_pred, re_pred

#### CAREModel
kiến trúc chính của mô hình CARE, bao gồm 3 thành phần chính:

1. Encoder Module: Input → BERT → MLPs (tách thành NER-specified, RE-specified).
    * Lớp CARE sử dụng BERT (self.bert = AutoModel.from_pretrained(...)).
    * Sau BERT, các vector đặc trưng được tách thành hai nhánh: cho NER và RE (thường qua các Linear layers/MLPs).

    * Kết hợp Distance Embedding: Sử dụng embedding cho khoảng cách giữa các token.
        * Lớp CARE có self.dist_emb = nn.Embedding(20, config.DIST_EMB_SIZE).
        * Ma trận khoảng cách được ánh xạ qua embedding này và kết hợp vào co-attention.

2. Co-Attention Module:  Kết hợp đặc trưng NER, RE với Distance Embedding. Qua Conv2d để tạo ra các đặc trưng chia sẻ (shared feature) và ma trận đồng chú ý (co-attention matrix).

    * Lớp ConvAttentionLayer và ConvAttention thực hiện việc này:
        * Nhận đầu vào là đặc trưng NER, RE, distance embedding.
        * Kết hợp (concatenate) các đặc trưng này.
        * Qua các lớp Conv2d để học tương tác cục bộ và đồng chú ý.
        * Trả về các đặc trưng đã attention hóa cho NER, RE, và shared feature.
3. Classification Module: Đặc trưng NER, RE (sau co-attention) được đưa vào các bảng (table) để dự đoán nhãn cho từng cặp vị trí (NER Table, RE Table).
    * Lớp NERUnit và RelationUnit:
        * Nhận đặc trưng đã attention hóa.
        * Kết hợp các cặp (start, end) hoặc (subject, object).
        * Qua các lớp Linear, LayerNorm, ELU, Dropout, Sigmoid để dự đoán xác suất nhãn cho từng cặp.
        * Đầu ra là các bảng (tensor) dự đoán NER và RE cho toàn bộ câu.


## Metric and Loss

In [13]:
class Metrics:
    def __init__(self, rel2idx, ner2idx, metric_type="micro"):
        self.rel2idx = rel2idx
        self.ner2idx = ner2idx
        self.metric_type = metric_type

    def count_predictions(self, ner_pred, ner_label, re_pred, re_label):
        ner_pred = (ner_pred >= 0.5).float()
        re_pred = (re_pred >= 0.5).float()

        # Get entity mask for relations
        ner_mask = torch.sum(ner_pred, dim=1).float()
        ner_mask = torch.sum(ner_mask, dim=-1).float()
        ner_mask = (ner_mask > 0).float()

        seq_len = ner_mask.size(0)
        e1_mask = ner_mask.unsqueeze(0).repeat(seq_len, 1, 1)
        e2_mask = ner_mask.unsqueeze(1).repeat(1, seq_len, 1)
        entity_pair_mask = e1_mask * e2_mask

        # Apply entity mask to relation predictions
        entity_pair_mask = entity_pair_mask.unsqueeze(-1).repeat(1, 1, 1, len(self.rel2idx))
        filtered_re_pred = re_pred * entity_pair_mask

        # Count predictions and gold labels
        pred_num = filtered_re_pred.sum().item()
        gold_num = re_label.sum().item()

        # Count correct predictions
        correct_re = filtered_re_pred + re_label
        correct_re = (correct_re == 2).float()

        # Check if entities are correctly identified
        ner_correct = ner_pred * ner_label
        ner_correct = torch.sum(ner_correct, dim=1)
        ner_correct = torch.sum(ner_correct, dim=-1)
        ner_correct = (ner_correct > 0).float()

        e1_correct = ner_correct.unsqueeze(0).repeat(seq_len, 1, 1)
        e2_correct = ner_correct.unsqueeze(1).repeat(1, seq_len, 1)
        entity_correct_mask = e1_correct * e2_correct
        entity_correct_mask = entity_correct_mask.unsqueeze(-1).repeat(1, 1, 1, re_label.size(-1))

        correct_re = correct_re * entity_correct_mask
        correct_num = correct_re.sum().item()

        return pred_num, gold_num, correct_num

    def count_entities(self, ner_pred, ner_label):
        ner_pred = (ner_pred >= 0.5).float()
        pred_num = ner_pred.sum().item()
        gold_num = ner_label.sum().item()
        correct_num = (ner_pred * ner_label).sum().item()
        return pred_num, gold_num, correct_num

def calculate_f1(pred_num, gold_num, correct_num):
    if pred_num == 0:
        precision = 0
    else:
        precision = correct_num / pred_num

    if gold_num == 0:
        recall = 0
    else:
        recall = correct_num / gold_num

    if precision + recall == 0:
        f1 = 0
    else:
        f1 = 2 * precision * recall / (precision + recall)

    return {"p": precision, "r": recall, "f": f1}

In [14]:
class CARELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.bce_loss = nn.BCELoss(reduction='sum')

    def forward(self, ner_pred, ner_label, re_pred, re_label):
        seq_len = ner_pred.size(1)
        ner_loss = self.bce_loss(ner_pred, ner_label) / seq_len
        re_loss = self.bce_loss(re_pred, re_label) / seq_len
        return ner_loss + re_loss


## Training and Evaluation

In [15]:
# ... existing code ...

## Training and Evaluation Functions

def train_model(model, train_loader, dev_loader, ner2idx, rel2idx, device, dataset_name, epochs=10):
    """Train the CARE model"""
    print(f"Training CARE model on {dataset_name} dataset...")

    # Setup
    optimizer = optim.Adam(model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)
    loss_fn = CARELoss()
    metrics = Metrics(rel2idx, ner2idx, config.EVAL_METRIC)

    # Training loop
    best_f1 = 0
    best_model_path = f"care_best_{dataset_name.lower()}.pt"

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total_steps = 0

        print(f"Epoch {epoch+1}/{epochs}")

        for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch+1}")):
            texts, ner_labels, re_labels, mask, dist = batch
            ner_labels = ner_labels.to(device)
            re_labels = re_labels.to(device)
            mask = mask.to(device)
            dist = dist.to(device)

            optimizer.zero_grad()

            ner_pred, re_pred = model(texts, mask, dist)
            loss = loss_fn(ner_pred, ner_labels, re_pred, re_labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.CLIP)
            optimizer.step()

            total_loss += loss.item()
            total_steps += 1

        avg_loss = total_loss / total_steps
        print(f"Average training loss: {avg_loss:.4f}")

        # Evaluate on dev set
        dev_results = evaluate_model(model, dev_loader, rel2idx, ner2idx, loss_fn, device, "dev")
        triple_f1 = dev_results[0]["f"]

        print(f"Dev Triple F1: {triple_f1:.4f}")

        # Save best model
        if triple_f1 > best_f1:
            best_f1 = triple_f1
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_f1': best_f1,
                'ner2idx': ner2idx,
                'rel2idx': rel2idx,
                'config': config.__dict__
            }, best_model_path)
            print(f"New best model saved with F1: {best_f1:.4f}")

        # Save checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            checkpoint_path = f"care_checkpoint_{dataset_name.lower()}_epoch_{epoch+1}.pt"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'ner2idx': ner2idx,
                'rel2idx': rel2idx,
                'config': config.__dict__
            }, checkpoint_path)
            print(f"Checkpoint saved: {checkpoint_path}")

    print(f"Training completed. Best F1: {best_f1:.4f}")
    return best_model_path

def evaluate_model(model, dataloader, rel2idx, ner2idx, loss_fn, device, split_name):
    """Evaluate the model on given dataloader"""
    model.eval()
    total_loss = 0
    total_steps = 0

    metrics = Metrics(rel2idx, ner2idx, config.EVAL_METRIC)
    entity_counts = [0, 0, 0]  # pred, gold, correct
    relation_counts = [0, 0, 0]

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Evaluating {split_name}"):
            texts, ner_labels, re_labels, mask, dist = batch
            ner_labels = ner_labels.to(device)
            re_labels = re_labels.to(device)
            mask = mask.to(device)
            dist = dist.to(device)

            ner_pred, re_pred = model(texts, mask, dist)
            loss = loss_fn(ner_pred, ner_labels, re_pred, re_labels)

            total_loss += loss.item()
            total_steps += 1

            # Count predictions
            pred_num, gold_num, correct_num = metrics.count_predictions(ner_pred, ner_labels, re_pred, re_labels)
            relation_counts[0] += pred_num
            relation_counts[1] += gold_num
            relation_counts[2] += correct_num

            # Count entities
            ner_pred_num, ner_gold_num, ner_correct_num = metrics.count_entities(ner_pred, ner_labels)
            entity_counts[0] += ner_pred_num
            entity_counts[1] += ner_gold_num
            entity_counts[2] += ner_correct_num

    avg_loss = total_loss / total_steps

    # Calculate F1 scores
    triple_result = calculate_f1(relation_counts[0], relation_counts[1], relation_counts[2])
    entity_result = calculate_f1(entity_counts[0], entity_counts[1], entity_counts[2])

    print(f"------ {split_name.upper()} Results ------")
    print(f"Loss: {avg_loss:.4f}")
    print(f"Entity - P: {entity_result['p']:.4f}, R: {entity_result['r']:.4f}, F1: {entity_result['f']:.4f}")
    print(f"Triple - P: {triple_result['p']:.4f}, R: {triple_result['r']:.4f}, F1: {triple_result['f']:.4f}")

    return triple_result, entity_result, avg_loss


## Train WEBLLG

In [None]:
set_seed(config.SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Setup logging
logger = setup_logging("care_outputs")

# Train on WEBNLG dataset
print("=" * 50)
print("TRAINING ON WEBNLG DATASET")
print("=" * 50)

model_weblg = CARE(ner2idx_weblg, rel2idx_weblg)
model_weblg.to(device)

best_model_path_weblg = train_model(
    model_weblg,
    train_loader_weblg,
    dev_loader_weblg,
    ner2idx_weblg,
    rel2idx_weblg,
    device,
    'WEBNLG'
)

# Evaluate on test set
print("\nEvaluating WEBNLG model on test set...")
test_results_weblg = evaluate_model(
    model_weblg,
    test_loader_weblg,
    rel2idx_weblg,
    ner2idx_weblg,
    CARELoss(),
    device,
    "test"
)


Using device: cuda
TRAINING ON WEBNLG DATASET


2025-07-19 17:12:38.234542: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752945158.587514      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752945158.694995      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Training CARE model on WEBNLG dataset...
Epoch 1/10


Training Epoch 1: 100%|██████████| 628/628 [03:32<00:00,  2.95it/s]


Average training loss: 429.2165


Evaluating dev: 100%|██████████| 63/63 [00:05<00:00, 11.33it/s]


------ DEV Results ------
Loss: 2.3302
Entity - P: 0.8639, R: 0.9505, F1: 0.9051
Triple - P: 0.8154, R: 0.0477, F1: 0.0901
Dev Triple F1: 0.0901
New best model saved with F1: 0.0901
Epoch 2/10


Training Epoch 2: 100%|██████████| 628/628 [03:44<00:00,  2.80it/s]


Average training loss: 1.4331


Evaluating dev: 100%|██████████| 63/63 [00:05<00:00, 11.31it/s]


------ DEV Results ------
Loss: 1.1871
Entity - P: 0.9095, R: 0.9753, F1: 0.9413
Triple - P: 0.8952, R: 0.5072, F1: 0.6475
Dev Triple F1: 0.6475
New best model saved with F1: 0.6475
Epoch 3/10


Training Epoch 3: 100%|██████████| 628/628 [03:42<00:00,  2.82it/s]


Average training loss: 0.7716


Evaluating dev: 100%|██████████| 63/63 [00:05<00:00, 11.29it/s]


------ DEV Results ------
Loss: 0.8423
Entity - P: 0.9279, R: 0.9714, F1: 0.9491
Triple - P: 0.8829, R: 0.7797, F1: 0.8281
Dev Triple F1: 0.8281
New best model saved with F1: 0.8281
Epoch 4/10


Training Epoch 4: 100%|██████████| 628/628 [03:41<00:00,  2.83it/s]


Average training loss: 0.5146


Evaluating dev: 100%|██████████| 63/63 [00:05<00:00, 11.31it/s]


------ DEV Results ------
Loss: 0.6848
Entity - P: 0.9492, R: 0.9603, F1: 0.9547
Triple - P: 0.8685, R: 0.8022, F1: 0.8340
Dev Triple F1: 0.8340
New best model saved with F1: 0.8340
Epoch 5/10


Training Epoch 5: 100%|██████████| 628/628 [03:43<00:00,  2.81it/s]


Average training loss: 0.3801


Evaluating dev: 100%|██████████| 63/63 [00:05<00:00, 11.34it/s]


------ DEV Results ------
Loss: 0.5971
Entity - P: 0.9607, R: 0.9714, F1: 0.9660
Triple - P: 0.8563, R: 0.9002, F1: 0.8777
Dev Triple F1: 0.8777
New best model saved with F1: 0.8777
Checkpoint saved: care_checkpoint_webnlg_epoch_5.pt
Epoch 6/10


Training Epoch 6: 100%|██████████| 628/628 [03:43<00:00,  2.81it/s]


Average training loss: 0.3121


Evaluating dev: 100%|██████████| 63/63 [00:05<00:00, 11.31it/s]


------ DEV Results ------
Loss: 0.5754
Entity - P: 0.9676, R: 0.9720, F1: 0.9698
Triple - P: 0.9063, R: 0.8786, F1: 0.8922
Dev Triple F1: 0.8922
New best model saved with F1: 0.8922
Epoch 7/10


Training Epoch 7: 100%|██████████| 628/628 [03:42<00:00,  2.83it/s]


Average training loss: 0.2568


Evaluating dev: 100%|██████████| 63/63 [00:05<00:00, 11.24it/s]


------ DEV Results ------
Loss: 0.4978
Entity - P: 0.9612, R: 0.9850, F1: 0.9730
Triple - P: 0.8877, R: 0.9029, F1: 0.8952
Dev Triple F1: 0.8952
New best model saved with F1: 0.8952
Epoch 8/10


Training Epoch 8: 100%|██████████| 628/628 [03:44<00:00,  2.80it/s]


Average training loss: 0.2372


Evaluating dev: 100%|██████████| 63/63 [00:05<00:00, 11.25it/s]


------ DEV Results ------
Loss: 0.4731
Entity - P: 0.9746, R: 0.9740, F1: 0.9743
Triple - P: 0.9268, R: 0.8876, F1: 0.9068
Dev Triple F1: 0.9068
New best model saved with F1: 0.9068
Epoch 9/10


Training Epoch 9: 100%|██████████| 628/628 [03:42<00:00,  2.83it/s]


Average training loss: 0.2152


Evaluating dev: 100%|██████████| 63/63 [00:05<00:00, 11.25it/s]


------ DEV Results ------
Loss: 0.5269
Entity - P: 0.9515, R: 0.9844, F1: 0.9677
Triple - P: 0.8882, R: 0.9218, F1: 0.9047
Dev Triple F1: 0.9047
Epoch 10/10


Training Epoch 10: 100%|██████████| 628/628 [03:42<00:00,  2.82it/s]


Average training loss: 0.1935


Evaluating dev: 100%|██████████| 63/63 [00:05<00:00, 11.30it/s]


------ DEV Results ------
Loss: 0.5457
Entity - P: 0.9634, R: 0.9779, F1: 0.9706
Triple - P: 0.9011, R: 0.9254, F1: 0.9130
Dev Triple F1: 0.9130
New best model saved with F1: 0.9130
Checkpoint saved: care_checkpoint_webnlg_epoch_10.pt
Training completed. Best F1: 0.9130

Evaluating WEBNLG model on test set...


Evaluating test: 100%|██████████| 88/88 [00:08<00:00, 10.65it/s]

------ TEST Results ------
Loss: 0.5583
Entity - P: 0.9592, R: 0.9776, F1: 0.9683
Triple - P: 0.8750, R: 0.9077, F1: 0.8910





## Train NYT

In [None]:
import torch
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

# Train on NYT dataset
set_seed(config.SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Setup logging
logger = setup_logging("care_outputs")

print("TRAINING ON NYT DATASET")
print("=" * 50)

model_nyt = CARE(ner2idx_nyt, rel2idx_nyt)
model_nyt.to(device)

best_model_path_nyt = train_model(
    model_nyt,
    train_loader_nyt,
    dev_loader_nyt,
    ner2idx_nyt,
    rel2idx_nyt,
    device,
    'NYT'
)

# Evaluate on test set
print("\nEvaluating NYT model on test set...")
test_results_nyt = evaluate_model(
    model_nyt,
    test_loader_nyt,
    rel2idx_nyt,
    ner2idx_nyt,
    CARELoss(),
    device,
    "test"
)

Using device: cuda
TRAINING ON NYT DATASET
Training CARE model on NYT dataset...
Epoch 1/10


Training Epoch 1: 100%|██████████| 7025/7025 [20:15<00:00,  5.78it/s]


Average training loss: 22.3834


Evaluating dev: 100%|██████████| 625/625 [00:44<00:00, 13.90it/s]


------ DEV Results ------
Loss: 0.2920
Entity - P: 0.8897, R: 0.9645, F1: 0.9256
Triple - P: 0.7908, R: 0.8799, F1: 0.8329
Dev Triple F1: 0.8329
New best model saved with F1: 0.8329
Epoch 2/10


Training Epoch 2: 100%|██████████| 7025/7025 [20:34<00:00,  5.69it/s]


Average training loss: 0.2074


Evaluating dev: 100%|██████████| 625/625 [00:44<00:00, 14.07it/s]


------ DEV Results ------
Loss: 0.2130
Entity - P: 0.9399, R: 0.9454, F1: 0.9426
Triple - P: 0.8707, R: 0.8889, F1: 0.8797
Dev Triple F1: 0.8797
New best model saved with F1: 0.8797
Epoch 3/10


Training Epoch 3: 100%|██████████| 7025/7025 [20:07<00:00,  5.82it/s]


Average training loss: 0.1435


Evaluating dev: 100%|██████████| 625/625 [00:43<00:00, 14.29it/s]


------ DEV Results ------
Loss: 0.2099
Entity - P: 0.9489, R: 0.9469, F1: 0.9479
Triple - P: 0.9074, R: 0.8912, F1: 0.8992
Dev Triple F1: 0.8992
New best model saved with F1: 0.8992
Epoch 4/10


Training Epoch 4: 100%|██████████| 7025/7025 [20:09<00:00,  5.81it/s]


Average training loss: 0.1058


Evaluating dev: 100%|██████████| 625/625 [00:45<00:00, 13.59it/s]


------ DEV Results ------
Loss: 0.2332
Entity - P: 0.9461, R: 0.9540, F1: 0.9500
Triple - P: 0.8911, R: 0.9100, F1: 0.9004
Dev Triple F1: 0.9004
New best model saved with F1: 0.9004
Epoch 5/10


Training Epoch 5: 100%|██████████| 7025/7025 [21:46<00:00,  5.38it/s]


Average training loss: 0.0836


Evaluating dev: 100%|██████████| 625/625 [00:50<00:00, 12.31it/s]


------ DEV Results ------
Loss: 0.2587
Entity - P: 0.9477, R: 0.9540, F1: 0.9509
Triple - P: 0.8992, R: 0.9157, F1: 0.9074
Dev Triple F1: 0.9074
New best model saved with F1: 0.9074
Checkpoint saved: care_checkpoint_nyt_epoch_5.pt
Epoch 6/10


Training Epoch 6: 100%|██████████| 7025/7025 [21:33<00:00,  5.43it/s]


Average training loss: 0.0678


Evaluating dev: 100%|██████████| 625/625 [00:58<00:00, 10.69it/s]


------ DEV Results ------
Loss: 0.2789
Entity - P: 0.9491, R: 0.9549, F1: 0.9520
Triple - P: 0.9114, R: 0.9109, F1: 0.9111
Dev Triple F1: 0.9111
New best model saved with F1: 0.9111
Epoch 7/10


Training Epoch 7: 100%|██████████| 7025/7025 [22:03<00:00,  5.31it/s]


Average training loss: 0.0580


Evaluating dev: 100%|██████████| 625/625 [00:50<00:00, 12.47it/s]


------ DEV Results ------
Loss: 0.2770
Entity - P: 0.9515, R: 0.9560, F1: 0.9537
Triple - P: 0.9075, R: 0.9210, F1: 0.9142
Dev Triple F1: 0.9142
New best model saved with F1: 0.9142
Epoch 8/10


Training Epoch 8: 100%|██████████| 7025/7025 [20:34<00:00,  5.69it/s]


Average training loss: 0.0481


Evaluating dev: 100%|██████████| 625/625 [00:47<00:00, 13.07it/s]


------ DEV Results ------
Loss: 0.2970
Entity - P: 0.9455, R: 0.9570, F1: 0.9512
Triple - P: 0.9049, R: 0.9198, F1: 0.9123
Dev Triple F1: 0.9123
Epoch 9/10


Training Epoch 9: 100%|██████████| 7025/7025 [20:31<00:00,  5.70it/s]


Average training loss: 0.0413


Evaluating dev: 100%|██████████| 625/625 [00:43<00:00, 14.38it/s]


------ DEV Results ------
Loss: 0.3272
Entity - P: 0.9415, R: 0.9603, F1: 0.9508
Triple - P: 0.8900, R: 0.9283, F1: 0.9087
Dev Triple F1: 0.9087
Epoch 10/10


Training Epoch 10: 100%|██████████| 7025/7025 [20:30<00:00,  5.71it/s]


Average training loss: 0.0360


Evaluating dev: 100%|██████████| 625/625 [00:48<00:00, 12.85it/s]


------ DEV Results ------
Loss: 0.3372
Entity - P: 0.9336, R: 0.9679, F1: 0.9505
Triple - P: 0.8888, R: 0.9324, F1: 0.9101
Dev Triple F1: 0.9101
Checkpoint saved: care_checkpoint_nyt_epoch_10.pt
Training completed. Best F1: 0.9142

Evaluating NYT model on test set...


Evaluating test: 100%|██████████| 625/625 [00:44<00:00, 14.05it/s]

------ TEST Results ------
Loss: 0.3457
Entity - P: 0.9327, R: 0.9710, F1: 0.9515
Triple - P: 0.8819, R: 0.9303, F1: 0.9055





## Load and Test Model


In [16]:
def load_trained_model(model_path, device):
    """Load a trained CARE model"""
    checkpoint = torch.load(model_path, map_location=device)

    # Recreate model with saved config
    ner2idx = checkpoint['ner2idx']
    rel2idx = checkpoint['rel2idx']

    model = CARE(ner2idx, rel2idx)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    return model, ner2idx, rel2idx


In [None]:
from huggingface_hub import login
login(token="")

In [18]:
from huggingface_hub import hf_hub_download
import torch

model_path_web = hf_hub_download(
    repo_id="LuongDat/care_webnlg",
    filename="care_best_webnlg.pt"
)
model_path_nyt = hf_hub_download(
    repo_id="LuongDat/care_webnlg",
    filename="care_best_nyt.pt"
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model_weblg, ner2idx_weblg, rel2idx_weblg = load_trained_model(model_path_web, device)
model_nyt, ner2idx_nyt, rel2idx_nyt = load_trained_model(model_path_nyt, device)


In [55]:
def predict_on_sample(model, tokenizer, text, ner2idx, rel2idx, device):
    """
    Dự đoán thực thể và quan hệ cho một câu đầu vào.
    - Nếu NER chỉ có 'None', chỉ trả về text, start, end.
    - Relations trả về subject, object, relation.
    """
    model.eval()
    words = text.split()
    if len(words) > config.MAX_SEQ_LEN:
        words = words[:config.MAX_SEQ_LEN]
    sent_str = ' '.join(words)
    bert_tokens = tokenizer.tokenize(sent_str)
    bert_len = len(bert_tokens) + 2 

    # Ánh xạ từ -> BERT
    word_to_bert = map_word_to_bert(words, tokenizer)

    # Tạo ma trận khoảng cách
    dis2idx = np.zeros(4000, dtype='int64')
    dis2idx[1] = 1
    dis2idx[2:] = 2
    dis2idx[4:] = 3
    dis2idx[8:] = 4
    dis2idx[16:] = 5
    dis2idx[32:] = 6
    dis2idx[64:] = 7
    dis2idx[128:] = 8
    dis2idx[256:] = 9

    dist_inputs = torch.zeros((bert_len, bert_len), dtype=torch.long)
    for k in range(bert_len):
        dist_inputs[k, :] += k
        dist_inputs[:, k] -= k
    for i in range(bert_len):
        for j in range(bert_len):
            if dist_inputs[i, j] < 0:
                dist_inputs[i, j] = int(dis2idx[-int(dist_inputs[i, j])]) + 9
            else:
                dist_inputs[i, j] = int(dis2idx[int(dist_inputs[i, j])])
    dist_inputs[dist_inputs == 0] = 19

    # Chuẩn bị batch
    texts = [words]
    mask = torch.ones(bert_len, 1).to(device)
    dist = dist_inputs.unsqueeze(0).to(device)

    with torch.no_grad():
        ner_pred, re_pred = model(texts, mask, dist)
    ner_pred = (ner_pred >= 0.5).float().squeeze(2)
    re_pred = (re_pred >= 0.5).float().squeeze(2)

    # Trích xuất entities (chỉ cần text, start, end nếu chỉ có 'None')
    entities = []
    for i in range(bert_len):
        for j in range(i, bert_len):
            for ner_type_idx, ner_type in enumerate(ner2idx.keys()):
                if ner_pred[i, j, ner_type_idx] == 1:
                    if ner_type == "None":
                        start_word = end_word = None
                        for word_idx, (start, end) in word_to_bert.items():
                            if start == i - 1:
                                start_word = word_idx
                            if end == j - 1:
                                end_word = word_idx
                        if start_word is not None and end_word is not None:
                            entity_text = ' '.join(words[start_word:end_word+1])
                            entities.append({
                                'text': entity_text,
                                'start': start_word,
                                'end': end_word
                            })

    # Trích xuất relations (luôn cần loại quan hệ)
    relations = []
    for i in range(bert_len):
        for j in range(bert_len):
            for rel_type_idx, rel_type in enumerate(rel2idx.keys()):
                if re_pred[i, j, rel_type_idx] == 1:
                    subj_entity = next((e for e in entities if e['start'] == i-1 and e['end'] == i-1), None)
                    obj_entity = next((e for e in entities if e['start'] == j-1 and e['end'] == j-1), None)
                    if subj_entity and obj_entity:
                        relations.append({
                            'subject': subj_entity['text'],
                            'object': obj_entity['text'],
                            'relation': rel_type
                        })
    return entities, relations

In [61]:
sample_texts_webnlg = [
    "Alan Bean ( of the United States ) was a crew member of NASA 's Apollo 12 under the commander David Scott .",
    "Alan Shepard , who was born in New Hampshire on November 18th , 1923 , graduated with a M.A . from NWC in 1957 and retired on August 1st , 1974 .",
    "American test pilot Alan Shepard died in California and was born in New Hampshire ."
]

for i, text in enumerate(sample_texts_webnlg, 1):
    print("-" * 40)
    print(f"\nSample {i}: {text}")
    print("-" * 40)

    # Predict with WEBNLG model
    print("WEBNLG Model Predictions:")
    entities_weblg, relations_weblg = predict_on_sample(
        model_weblg, tokenizer, text, ner2idx_weblg, rel2idx_weblg, device
    )

    print("Entities:")
    if entities_weblg:
        for entity in entities_weblg:
            print(f"  - {entity['text']}")
    else:
        print("No entities found.")

    print("Relations:")
    if relations_weblg:
        for relation in relations_weblg:
            print(f"  - {relation['subject']} --{relation['relation']}--> {relation['object']}")
    else:
        print("No relations found.")


----------------------------------------

Sample 1: Alan Bean ( of the United States ) was a crew member of NASA 's Apollo 12 under the commander David Scott .
----------------------------------------
WEBNLG Model Predictions:
Entities:
  - Bean
  - States
  - NASA
  - 12
  - Scott
Relations:
  - Bean --nationality--> States
----------------------------------------

Sample 2: Alan Shepard , who was born in New Hampshire on November 18th , 1923 , graduated with a M.A . from NWC in 1957 and retired on August 1st , 1974 .
----------------------------------------
WEBNLG Model Predictions:
Entities:
  - Shepard
  - Hampshire
Relations:
  - Shepard --birthPlace--> Hampshire
----------------------------------------

Sample 3: American test pilot Alan Shepard died in California and was born in New Hampshire .
----------------------------------------
WEBNLG Model Predictions:
Entities:
  - pilot
  - Shepard
  - California
  - Hampshire
Relations:
  - Shepard --occupation--> pilot
  - Shepard --

In [64]:
sample_texts_nyt = [
    "The New York Philharmonic will perform at Avery Fisher Hall in Lincoln Center .",
    "The Empire State Building is located in New York City .",
    "Harvard University is located in Cambridge , Massachusetts .",
    "The Statue of Liberty stands on Liberty Island in New York Harbor ."
]

for i, text in enumerate(sample_texts_nyt, 1):
    print("-" * 40)
    print(f"\nSample {i}: {text}")    

    # Predict with NYT model
    print("\nNYT Model Predictions:")
    entities_nyt, relations_nyt = predict_on_sample(
        model_nyt, tokenizer, text, ner2idx_nyt, rel2idx_nyt, device
    )

    print("Entities:")
    if entities_nyt:
        for entity in entities_nyt:
            print(f"  - {entity['text']}")
    else:
        print("No entities found.")

    print("Relations:")
    if relations_nyt:
        for relation in relations_nyt:
            print(f"  - {relation['subject']} --{relation['relation']}--> {relation['object']}")
    else:
        print("No relations found.")

----------------------------------------

Sample 1: The New York Philharmonic will perform at Avery Fisher Hall in Lincoln Center .

NYT Model Predictions:
Entities:
  - Philharmonic
  - Hall
  - Center
Relations:
  - Center --/location/location/contains--> Hall
----------------------------------------

Sample 2: The Empire State Building is located in New York City .

NYT Model Predictions:
Entities:
  - Building
  - City
Relations:
  - City --/location/location/contains--> Building
----------------------------------------

Sample 3: Harvard University is located in Cambridge , Massachusetts .

NYT Model Predictions:
Entities:
  - University
  - Cambridge
  - Massachusetts
Relations:
  - Cambridge --/location/location/contains--> University
  - Massachusetts --/location/location/contains--> Cambridge
----------------------------------------

Sample 4: The Statue of Liberty stands on Liberty Island in New York Harbor .

NYT Model Predictions:
Entities:
  - Island
  - Harbor
Relations:
