In [None]:
'''
Load raw statements from the ORKG knowledge graph.
Clean the statements and remove illegal characters.
Split the triples into train, valid and test .txt files.
'''
import os
import pickle
import random
import re

from orkg import ORKG
# Connect to the public ORKG instance. Credentials are read from the
# environment (ORKG_EMAIL / ORKG_PASSWORD) instead of being hard-coded:
# a password committed to source is exposed to anyone who can read the file.
# NOTE(review): the password previously committed here should be treated as
# compromised and rotated.
_orkg_email = os.environ.get("ORKG_EMAIL")
_orkg_password = os.environ.get("ORKG_PASSWORD")
if not _orkg_email or not _orkg_password:
    raise RuntimeError(
        "Set the ORKG_EMAIL and ORKG_PASSWORD environment variables "
        "to authenticate against https://orkg.org"
    )
orkg = ORKG(host="https://orkg.org", creds=(_orkg_email, _orkg_password))

# Request resource pages 1..100.
# NOTE(review): `response` is not referenced again in this script as shown;
# consider dropping this (potentially expensive) request.
response = orkg.resources.get_unpaginated(start_page=1, end_page=100)

'''
print(response.all_succeeded)
print(len(response.responses))
print(len(response.content))
print(response.content[:3])
'''
# NOTE(review): `literal` is not referenced again in this script as shown.
literal = orkg.literals.by_id(id="A1")


# Request statements (size=5000, sorted by 'id' descending); their content is
# the raw material for the triples extracted below.
ent_by_id = orkg.statements.get(size=5000, sort='id', desc=True)
print(ent_by_id.content[:5000])

# Extract (subject, predicate, object) label triples from the fetched
# statements, skipping any statement that cannot be written as one TSV line.
data = []
for item in ent_by_id.content:
    try:
        subject = item['subject']['label']
        predicate = item['predicate']['label']
        obj = item['object']['label']
    except (KeyError, TypeError):
        # Statement lacks a labelled subject/predicate/object: skip it.
        continue
    labels = (subject, predicate, obj)
    # Reject non-string labels (e.g. None) explicitly so that
    # remove_illegal_chars, which calls re.sub, always receives a str later.
    if not all(isinstance(label, str) for label in labels):
        continue
    # Tabs separate fields and line breaks separate triples in the output
    # files. "\r" must be rejected as well: remove_illegal_chars keeps all
    # whitespace, and text-mode reads (universal newlines) split lines on "\r",
    # which would tear one triple across two lines.
    if any(ch in label for label in labels for ch in ("\n", "\r", "\t")):
        continue
    # NOTE(review): this drops every triple whose labels contain the substring
    # "path" anywhere (also e.g. "pathology"); confirm this is intended.
    if any("path" in label for label in labels):
        continue
    data.append((subject, predicate, obj))


def split_dataset(data, train_ratio, valid_ratio, test_ratio, *, seed=None):
    """Shuffle ``data`` and split it into train, validation and test parts.

    Args:
        data: Sequence of samples. It is copied before shuffling, so the
            caller's list is left unmodified.
        train_ratio: Fraction of samples placed in the training split.
        valid_ratio: Fraction of samples placed in the validation split.
        test_ratio: Fraction intended for the test split. The test split
            receives every sample not assigned to train/valid (so rounding
            remainders land there); this value is used to check that the three
            ratios are consistent.
        seed: Optional seed for a private ``random.Random`` to make the split
            reproducible. ``None`` shuffles with the global ``random``
            generator, which is affected by ``random.seed``.

    Returns:
        A ``(train_data, valid_data, test_data)`` tuple of lists.

    Raises:
        ValueError: If a ratio is negative or the ratios do not sum to 1.
    """
    ratios = (train_ratio, valid_ratio, test_ratio)
    if any(ratio < 0 for ratio in ratios):
        raise ValueError(f"split ratios must be non-negative, got {ratios}")
    # Tolerance absorbs float error, e.g. 0.7 + 0.2 + 0.1 == 0.9999999999999999.
    if abs(sum(ratios) - 1.0) > 1e-6:
        raise ValueError(f"split ratios must sum to 1, got {ratios}")

    shuffled = list(data)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(shuffled)

    train_size = int(len(shuffled) * train_ratio)
    valid_size = int(len(shuffled) * valid_ratio)

    train_data = shuffled[:train_size]
    valid_data = shuffled[train_size:train_size + valid_size]
    test_data = shuffled[train_size + valid_size:]

    return train_data, valid_data, test_data



# Shuffle the triples and split them 80% / 10% / 10% into train / valid / test.
train_data, valid_data, test_data = split_dataset(
    data,
    train_ratio=0.8,
    valid_ratio=0.1,
    test_ratio=0.1,
)



def remove_illegal_chars(text):
    """Return ``text`` with every character deleted that is neither a word
    character (letter, digit, underscore) nor whitespace.

    Both character classes follow Python's Unicode-aware ``re`` semantics, so
    accented letters survive while punctuation and symbols are dropped.
    Whitespace is kept as-is.
    """
    illegal_char_re = re.compile(r"[^\w\s]")
    return illegal_char_re.sub("", text)





# Save each split as a tab-separated "subject<TAB>predicate<TAB>object" file.
# All three files go to the same relative ``dataset`` directory. The clean-up
# step below reads train.txt from that relative directory, so writing the
# training split to an absolute, machine-specific path would leave that step
# reading a stale or missing file.
train_file_path = 'dataset/train.txt'
valid_file_path = 'dataset/valid.txt'
test_file_path = 'dataset/test.txt'
os.makedirs('dataset', exist_ok=True)

for split_path, split_data in (
    (train_file_path, train_data),
    (valid_file_path, valid_data),
    (test_file_path, test_data),
):
    with open(split_path, 'w', encoding='utf-8') as split_file:
        for item in split_data:
            # Punctuation/symbols are stripped here, so a field can end up
            # empty; such lines are removed by the clean-up step below.
            subject = remove_illegal_chars(item[0])
            predicate = remove_illegal_chars(item[1])
            obj = remove_illegal_chars(item[2])
            split_file.write(f"{subject}\t{predicate}\t{obj}\n")



def _drop_incomplete_triples(path):
    """Rewrite the TSV file at ``path`` in place, keeping only complete triples.

    A line is complete when it has exactly three tab-separated fields and none
    of them is empty or whitespace-only. Whitespace-only fields must be caught
    explicitly: remove_illegal_chars keeps spaces, so a label made only of
    punctuation and spaces (e.g. "( )") cleans down to blanks.

    Returns:
        The 1-based line numbers (in the file as read) of the dropped lines.
    """
    with open(path, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    cleaned_lines = []
    incomplete_lines = []
    for i, line in enumerate(lines, start=1):
        fields = line.strip().split('\t')
        if len(fields) == 3 and all(field.strip() for field in fields):
            cleaned_lines.append(line)
        else:
            incomplete_lines.append(i)

    with open(path, 'w', encoding='utf-8') as new_file:
        new_file.writelines(cleaned_lines)

    return incomplete_lines


# Report the dropped line numbers for each split file, in the order
# train, test, valid.
for split_path in ('dataset/train.txt', 'dataset/test.txt', 'dataset/valid.txt'):
    print("不完整的三元组行行数：", _drop_incomplete_triples(split_path))

