In [None]:
import nltk
# Fetch the NLTK corpora read later through nltk.corpus.stopwords and
# nltk.corpus.words (used to build stop_words / english_vocab for filtering).
nltk.download('stopwords')
nltk.download('words')

import json
from tqdm import tqdm
from stanfordcorenlp import StanfordCoreNLP
import re
import string
import pandas as pd
import json  # duplicate of the `import json` above; harmless
from nltk.corpus import stopwords
from nltk.corpus import words

## Get Sentences

In [None]:
# Which dataset's training captions to export below. Swap the comment to switch
# between 'msvd' and 'msrvtt'.
# dataset = 'msvd'
dataset = 'msrvtt'

def msvd_load_captions(caption_fpath):
    """Return the English, non-missing captions of an MSVD caption CSV.

    Args:
        caption_fpath: path (or file-like object) accepted by ``pd.read_csv``;
            the CSV must have 'Language' and 'Description' columns.

    Returns:
        numpy array of the 'Description' values whose 'Language' is 'English'
        and whose description is not null, in file order.
    """
    frame = pd.read_csv(caption_fpath)
    is_english = frame['Language'] == 'English'
    has_text = frame['Description'].notnull()
    return frame.loc[is_english & has_text, 'Description'].values

if dataset == 'msvd':
    # Export every English MSVD training caption, one per line, to sentence_path.
    cnt = 0
    train_id_path = '../MSVD/metadata/train.list'  # not read in this cell
    train_sentence_path = '../MSVD/metadata/train.csv'
    sentence_path = '../MSVD/metadata/msvd_sentence.txt'
    # Load first so a bad CSV does not leave an open, half-written output file.
    sentence = msvd_load_captions(train_sentence_path)
    print(sentence.size)
    # 'w' (not 'a') so re-running the cell regenerates the file instead of
    # appending a second copy of every caption.
    with open(sentence_path, 'w') as w:
        for caption in sentence:
            w.write(f"{caption}\n")
            cnt += 1
    print(cnt)

if dataset == 'msrvtt':
    # Export every caption value from the MSR-VTT train.json, one per line.
    file_path = '../MSR-VTT/metadata/train.json'
    sentence_path = '../MSR-VTT/metadata/msrvtt_sentence.txt'
    cnt = 0

    with open(file_path, 'r') as f:
        load_data = json.load(f)
    # NOTE(review): presumably {video_id: {caption_key: caption_str}}; each caption
    # value is assumed to be a str -- confirm against train.json.
    list1 = list(load_data.values())
    # 'w' (not 'a') so re-running the cell does not duplicate every caption.
    with open(sentence_path, 'w') as w:
        for captions in list1:
            for caption in captions.values():
                w.write(f"{caption}\n")
                cnt += 1
    # Only the total is printed; echoing every caption flooded the notebook output.
    print(cnt)

## Extract triples from sentences using Stanford CoreNLP

In [None]:
dataset = 'msvd'  # NOTE(review): overrides the dataset chosen in the "Get Sentences" cell
# Read the sentence file written by the "Get Sentences" cell for this dataset.
sentence_paths = {
    'msvd': '../MSVD/metadata/msvd_sentence.txt',
    'msrvtt': '../MSR-VTT/metadata/msrvtt_sentence.txt',
}
sentence_path = sentence_paths[dataset]
# Must equal the name the "Filter triplets" cell loads: dataset + "_entity_total.txt".
entity_rel_path = dataset + '_entity_total.txt'
log_path = 'error_log.txt'

# Load the input before creating the StanfordCoreNLP instance so a missing file
# fails before there is anything to nlp.close().
with open(sentence_path, 'r') as f:
    lines = f.readlines()

nlp = StanfordCoreNLP('stanford-corenlp-4.2.2', lang='en')
w = open(entity_rel_path, 'w')
log = open(log_path, 'w')
cnt = 0       # number of CoreNLP sentences processed
num_rel = 0   # number of triples written to entity_rel_path

# Run CoreNLP OpenIE on every input line and keep (noun subject, noun object,
# verb relation) lemma triples. try/finally closes the output files and the
# StanfordCoreNLP instance even if the loop is interrupted (e.g. a kernel
# interrupt), which the per-line `except Exception` does not catch.
try:
    for line_no, line in enumerate(tqdm(lines), start=1):
        sentence = line.strip()
        if not sentence:
            # Blank lines carry no triples; skip the CoreNLP round-trip.
            continue
        try:
            output = nlp.annotate(sentence, properties={
                "annotators": "tokenize,lemma,ssplit,pos,depparse,natlog,openie",
                "outputFormat": "json",
                'openie.triple.strict': 'true',
                'openie.max_entailments_per_clause': '1'
            })

            data = json.loads(output)
            for sent in data['sentences']:
                openie_rels = sent["openie"]
                tokens = sent["tokens"]
                cnt += 1

                for rel in openie_rels:
                    subj_start, subj_end = rel['subjectSpan']
                    obj_start, obj_end = rel['objectSpan']
                    rel_start, rel_end = rel['relationSpan']

                    if max(subj_end, obj_end, rel_end) > len(tokens):
                        log.write(f"[Warning] Span out of range at sentence {cnt}: {sentence}\n")
                        continue

                    # Build each field with ' '.join over the span's lemmas, keeping
                    # only tokens whose POS tag starts with NN (subject/object) or
                    # VB (relation).
                    l_subject = ' '.join(
                        token['lemma'] for token in tokens[subj_start:subj_end]
                        if token['pos'].startswith('NN')
                    )
                    l_object = ' '.join(
                        token['lemma'] for token in tokens[obj_start:obj_end]
                        if token['pos'].startswith('NN')
                    )
                    l_relation = ' '.join(
                        token['lemma'] for token in tokens[rel_start:rel_end]
                        if token['pos'].startswith('VB')
                    )

                    if not (l_subject and l_object and l_relation):
                        continue

                    # On-disk field order: subject &object &relation.
                    relationSent = f"{l_subject.strip()} &{l_object.strip()} &{l_relation.strip()}"
                    w.write(relationSent + '\n')
                    num_rel += 1

        except Exception as e:
            # cnt only advances for sentences CoreNLP returned, so it does not
            # identify the failing input; report the input line number instead.
            log.write(f"[Error] Exception at line {line_no}: {sentence}\n")
            log.write(f"         Error message: {str(e)}\n")
finally:
    w.close()
    f.close()
    log.close()
    nlp.close()

In [None]:
# Report where the OpenIE cell wrote its output and how much it extracted.
for label, location in (
    ('Result written to:', entity_rel_path),
    ('Error log written to:', log_path),
):
    print(label, location)
print(f'Total number of processed sentences: {cnt}')
print(f'Total number of relations extracted: {num_rel}')

## Filter triplets

In [None]:
# Vocabularies consulted by the filtering helpers below.
english_vocab = set(words.words())            # NLTK English word list
stop_words = set(stopwords.words('english'))  # NLTK English stop-words
# Optional whitelist: tokens in akg_vocab are accepted by is_valid_entity even if
# they are not in english_vocab. TODO: set to the model's vocabulary (a set of str).
akg_vocab = None

def load_triples(filepath):
    """Parse an '&'-separated triple file written by the OpenIE cell.

    Each line is split on '&'; lines that do not yield exactly three fields are
    skipped. Fields are whitespace-stripped.

    Args:
        filepath: path to a text file with "head &tail &relation" lines.

    Returns:
        List of (head, tail, relation) tuples, in file order -- note the tuple
        follows the on-disk field order, not (head, relation, tail).
    """
    with open(filepath, 'r') as handle:
        rows = (raw.strip().split('&') for raw in handle)
        return [tuple(piece.strip() for piece in row) for row in rows if len(row) == 3]

def is_valid_entity(entity):
    """Return True if every whitespace token of `entity` is acceptable.

    A token is acceptable if it is in the optional `akg_vocab` whitelist, or it
    is longer than one character, purely alphabetic and in `english_vocab`.
    An empty entity is vacuously valid.
    """
    # akg_vocab is left as None in the config cell until a model vocabulary is
    # supplied; `token in None` would raise TypeError, so treat None as empty.
    whitelist = akg_vocab or ()
    for token in entity.split():
        if token in whitelist:
            continue
        if len(token) == 1 or not token.isalpha() or token not in english_vocab:
            return False
    return True

def filter_triples(triples):
    """Keep triples with an informative relation and short, valid entities.

    A (head, relation, tail) triple is dropped when the relation is a generic
    verb from the stop-list, when head or tail has zero or more than two
    whitespace tokens, or when is_valid_entity rejects head or tail.

    Returns:
        New list of the surviving (head, relation, tail) tuples, in input order.
    """
    banned_relations = {"be", "have", "do", "get", "can", "is", "are"}

    def _keep(head, rel, tail):
        if rel in banned_relations:
            return False
        head_len, tail_len = len(head.split()), len(tail.split())
        if min(head_len, tail_len) < 1 or max(head_len, tail_len) > 2:
            return False
        return is_valid_entity(head) and is_valid_entity(tail)

    return [(head, rel, tail) for head, rel, tail in triples if _keep(head, rel, tail)]

def normalize_triples(triples):
    """Lower-case, strip punctuation and drop stop-words/quantifiers in each field.

    Each triple is unpacked as (head, relation, tail). Triples where any field
    becomes empty after normalization are discarded.

    Returns:
        List of normalized (head, relation, tail) tuples, in input order.
    """
    quantifiers = {
        "some", "many", "several", "few", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten",
        "all", "both", "each", "every", "no", "none", "any", "most", "more", "less"
    }
    quantifiers.update(str(i) for i in range(1, 51))
    # Build the drop-set once rather than re-unioning it for every token.
    drop_words = stop_words.union(quantifiers)
    # str.translate deletes every character of string.punctuation, including the
    # backslash, which a regex class built from string.punctuation treats as an
    # escape and therefore never removes.
    strip_punct = str.maketrans('', '', string.punctuation)

    def normalize_text(text):
        tokens = text.lower().translate(strip_punct).split()
        return ' '.join(tok for tok in tokens if tok not in drop_words)

    norm_triples = []
    for h, r, t in triples:
        norm_h = normalize_text(h)
        norm_t = normalize_text(t)
        norm_r = normalize_text(r)
        if norm_h and norm_t and norm_r:
            norm_triples.append((norm_h, norm_r, norm_t))
    return norm_triples

def deduplicate_triples(triples):
    """Remove duplicate triples, keeping the first occurrence of each.

    dict.fromkeys preserves insertion order, so the result is deterministic.
    list(set(...)) ordered string tuples by hash, which varies between
    interpreter runs and made the seeded train/valid/test split irreproducible.
    """
    return list(dict.fromkeys(triples))

# load_triples yields (head, tail, relation) in the on-disk order, but
# normalize_triples, filter_triples and the TransE cell all unpack triples as
# (head, relation, tail). Without this reorder the relation stop-list is applied
# to tails and TransE maps tails as relations.
triples = [(h, r, t) for h, t, r in load_triples(dataset + "_entity_total.txt")]
triples = normalize_triples(triples)
triples = filter_triples(triples)
triples = deduplicate_triples(triples)

# Write back in the same "head &tail &relation" layout as entity_total.txt.
with open(dataset + "_filtered_entities.txt", "w") as f:
    for h, r, t in triples:
        relationSent = f"{h} &{t} &{r}"
        f.write(relationSent + '\n')
print('number of triples:', len(triples))

## Build dataset for TransE

In [None]:
import random
import os

def generate_mappings(triples):
    """Assign consecutive integer ids to entities and relations.

    Ids start at 0 and follow first appearance: for each (head, relation, tail)
    the head is seen before the tail, and relations are numbered separately.

    Returns:
        (entity2id, relation2id) dicts mapping names to ids.
    """
    entity2id = {}
    relation2id = {}
    for head, rel, tail in triples:
        # len() of a dict that only grows equals the next unused id.
        entity2id.setdefault(head, len(entity2id))
        entity2id.setdefault(tail, len(entity2id))
        relation2id.setdefault(rel, len(relation2id))
    return entity2id, relation2id

def split_data(triples, val_ratio=0.1, test_ratio=0.1, seed=42):
    """Shuffle triples reproducibly and split them into train / validation / test.

    Args:
        triples: sequence of triples; it is copied, never shuffled in place, so
            re-running the cell always sees the same input order.
        val_ratio: fraction of triples for validation (size floored with int()).
        test_ratio: fraction of triples for test (size floored with int()).
        seed: seed for a private random.Random; the global `random` state is
            left untouched. Produces the same permutation as
            ``random.seed(seed); random.shuffle(...)``.

    Returns:
        (train_triples, val_triples, test_triples): test is the first slice of
        the shuffled order, validation the next, train the remainder.
    """
    shuffled = list(triples)
    random.Random(seed).shuffle(shuffled)
    n = len(shuffled)
    test_size = int(n * test_ratio)
    val_size = int(n * val_ratio)
    test_triples = shuffled[:test_size]
    val_triples = shuffled[test_size:test_size + val_size]
    train_triples = shuffled[test_size + val_size:]
    return train_triples, val_triples, test_triples

def write_output(train_triples, val_triples, test_triples, entity2id, relation2id, path):
    """Write the split triples and id maps as train2id/entity2id-style text files.

    Creates `path` if needed and (over)writes train2id.txt, valid2id.txt,
    test2id.txt, entity2id.txt and relation2id.txt in it. Every file starts with
    its entry count. Triple files hold "head_id<TAB>tail_id<TAB>relation_id"
    lines, and map files hold "name<TAB>id" lines. NOTE(review): presumably the
    layout expected by OpenKE's TransE benchmark loader -- confirm.

    Args:
        train_triples, val_triples, test_triples: sized iterables of
            (head, relation, tail) tuples.
        entity2id: mapping entity -> int id containing every head and tail.
        relation2id: mapping relation -> int id.
        path: output directory.

    Raises:
        KeyError: if a triple references an entity or relation missing from the maps.
    """
    # A fresh checkout has no ./benchmarks/<dataset> directory yet.
    os.makedirs(path, exist_ok=True)

    def _write_triples(filename, triples):
        # Mode 'w' truncates any previous file, so no explicit delete is needed.
        with open(os.path.join(path, filename), 'w') as f:
            f.write(f"{len(triples)}\n")
            for h, r, t in triples:
                f.write(f"{entity2id[h]}\t{entity2id[t]}\t{relation2id[r]}\n")

    def _write_mapping(filename, mapping):
        with open(os.path.join(path, filename), 'w') as f:
            f.write(f"{len(mapping)}\n")
            for name, idx in mapping.items():
                f.write(f"{name}\t{idx}\n")

    _write_triples('train2id.txt', train_triples)
    _write_triples('valid2id.txt', val_triples)
    _write_triples('test2id.txt', test_triples)
    _write_mapping('entity2id.txt', entity2id)
    _write_mapping('relation2id.txt', relation2id)

# Ids come from all triples (before splitting) so every split can be encoded.
entity2id, relation2id = generate_mappings(triples)
train_triples, val_triples, test_triples = split_data(triples)

for label, count in [
    ("Number of entities", len(entity2id)),
    ("Number of relations", len(relation2id)),
    ("Train triples", len(train_triples)),
    ("Validation triples", len(val_triples)),
    ("Test triples", len(test_triples)),
]:
    print(f"{label}: {count}")

write_output(train_triples, val_triples, test_triples, entity2id, relation2id, path=f"./benchmarks/{dataset}")