In [41]:
import os
import math
import json
import uuid
import random
import pickle
import logging
import linecache
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from glob import glob
from random import shuffle

random.seed(46)
logging.basicConfig(format='%(asctime)s | %(levelname)s | %(message)s', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S') 

In [42]:
with open("../data/3_openKE/relation2id.txt") as fp:
    next(fp)
    
    relations = [' '.join(line.strip().split(' ')[:-1]) for line in fp.readlines()]

print(relations)
print(len(relations))

['Process Create', 'Process Start', 'CreateFile', 'SetBasicInformationFile', 'SetDispositionInformationEx', 'SetDispositionInformationFile', 'WriteFile', 'TCP Connect', 'TCP Send', 'UDP Send', 'TCP Disconnect', 'RegQueryKey', 'RegQueryValue', 'CloseFile', 'QueryAllInformationFile', 'QueryAttributeTagFile', 'QueryBasicInformationFile', 'QueryDirectory', 'QueryNetworkOpenInformationFile', 'ReadFile', 'TCP Receive', 'UDP Receive', 'RegCreateKey', 'RegSetValue', 'RegCloseKey', 'RegDeleteValue', 'RegOpenKey']
27


In [43]:
type2attr = {
    "Process": "Cmdline", 
    "File": "Name", 
    "Registry": "Key", 
    "Network": "Dstaddress"
}

In [44]:
nodes = set()
edges = set()
all_tuple = set()

labels = set()
all_triplets = set()



def collect_resource(events):
    global nodes, edges, all_tuple, relations
    for e in events:
        
        if e["relation"] not in relations:
            print(e["relation"], ' is not in relation')
            print(e["relation"])
            continue
            
        srcNode = str(e["srcNode"][type2attr[e["srcNode"]["Type"]]])
        dstNode = str(e["dstNode"][type2attr[e["dstNode"]["Type"]]]) if e["dstNode"] != None else srcNode
        relation = e["relation"]
        label = e["label"]
        # print(label)
        
        nodes.add(srcNode)
        nodes.add(dstNode)
        edges.add(relation)
        labels.add(label)

        all_tuple.add((srcNode, relation, dstNode))
        all_triplets.add((srcNode, relation, dstNode, label))
        

path = glob(f'../data/TrainingData/*/number_*/expanded_instance.json')
for p in tqdm(path):
    with open(p) as fp:
        events = json.load(fp)
    collect_resource(events)

logging.info(f"Nodes: {len(nodes)}, Edges: {len(edges)}, Labels: {len(labels)}")

  0%|          | 0/16900 [00:00<?, ?it/s]

2024-02-02 22:54:10 | INFO | Nodes: 824642, Edges: 25, Labels: 278


In [68]:
# events

In [46]:
relations

['Process Create',
 'Process Start',
 'CreateFile',
 'SetBasicInformationFile',
 'SetDispositionInformationEx',
 'SetDispositionInformationFile',
 'WriteFile',
 'TCP Connect',
 'TCP Send',
 'UDP Send',
 'TCP Disconnect',
 'RegQueryKey',
 'RegQueryValue',
 'CloseFile',
 'QueryAllInformationFile',
 'QueryAttributeTagFile',
 'QueryBasicInformationFile',
 'QueryDirectory',
 'QueryNetworkOpenInformationFile',
 'ReadFile',
 'TCP Receive',
 'UDP Receive',
 'RegCreateKey',
 'RegSetValue',
 'RegCloseKey',
 'RegDeleteValue',
 'RegOpenKey']

In [47]:
len(relations)

27

In [48]:
entity2id = {n:idx for idx, n in tqdm(enumerate(list(nodes)))}
relation2id = {rel:idx for idx, rel in tqdm(enumerate(relations))}
label2id = {l:idx for idx, l in tqdm(enumerate(list(labels)))}

print(len(entity2id))
print(len(relation2id))
print(len(label2id))

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

824642
27
278


In [49]:
all_tuple = list(all_tuple)
# shuffle(all_tuple)
train_tuple = all_tuple[:int(len(all_tuple)*0.8)]
valid_tuple = all_tuple[int(len(all_tuple)*0.8):int(len(all_tuple)*0.9)]
test_tuple = all_tuple[int(len(all_tuple)*0.9):]

train2id = []
valid2id = []
test2id = []

for t in tqdm(train_tuple):
    train2id.append((entity2id[t[0]], entity2id[t[2]], relation2id[t[1]]))

for t in tqdm(valid_tuple):
    valid2id.append((entity2id[t[0]], entity2id[t[2]], relation2id[t[1]]))

for t in tqdm(test_tuple):
    test2id.append((entity2id[t[0]], entity2id[t[2]], relation2id[t[1]]))

  0%|          | 0/3498383 [00:00<?, ?it/s]

  0%|          | 0/437298 [00:00<?, ?it/s]

  0%|          | 0/437298 [00:00<?, ?it/s]

In [50]:
list(all_tuple)[0]

('6228',
 'QueryDirectory',
 'C:\\Users\\nsilva\\AppData\\Local\\Packages\\Microsoft.MicrosoftOfficeHub_8wekyb3d8bbwe\\AppData\\.git')

In [51]:
list(all_tuple)[0]

('6228',
 'QueryDirectory',
 'C:\\Users\\nsilva\\AppData\\Local\\Packages\\Microsoft.MicrosoftOfficeHub_8wekyb3d8bbwe\\AppData\\.git')

In [52]:
train2id[0]

(66161, 463231, 17)

In [53]:
print(entity2id['4379921687c9a557f14c36d23279d6d4b8304e34d2cc904f9c8b9efed7824778.bin'])
print(entity2id['HKLM\\System\\CurrentControlSet\\Control\\SafeBoot\\Option'])
print(entity2id['\x1f@028;0 2K5740 70 3@0=8FC A>B@C4=8:0<.exe'])
print(relation2id['RegOpenKey'])

58107
322972
760725
26


In [54]:
# # import re

# # file_path = '../data/3_openKE/entity2id.txt'

# # with open(file_path, 'r') as file:
# #     next(file)  # 跳过第一行
# #     ent2id = {}
# #     for line in file:
# #         match = re.match(r'(.+)\s(\d+)$', line.strip())
# #         if match:
# #             entity, entity_id = match.groups()
# #             ent2id[entity] = int(entity_id)

# # # 测试是否能找到特定的键
# # print(ent2id.get('\x1f@028;0 2K5740 70 3@0=8FC A>B@C4=8:0<.exe'))

# with open(file_path, 'r') as file:
#     next(file)
#     # 使用正则表达式去除行末的数字
#     dictionary = {re.sub(r'\s\d+$', '', line.strip()): index for index, line in enumerate(file)}

- entity2id['\x1f@028;0 2K5740 70 3@0=8FC A>B@C4=8:0<.exe']   
- turn to @028;0 2K5740 70 3@0=8FC A>B@C4=8:0<.exe 30215

In [55]:
# dictionary['\x1f@028;0 2K5740 70 3@0=8FC A>B@C4=8:0<.exe']

In [56]:
# with open (f'{save_path}/entity2id.pkl', 'rb') as fp:
#     ent2id = pickle.load(fp)

In [57]:
# ent2id['\x1f@028;0 2K5740 70 3@0=8FC A>B@C4=8:0<.exe']

In [58]:
import pickle

save_path = "../data/3_openKE"
if not os.path.isdir(f'{save_path}'):
    os.makedirs(f'{save_path}', exits_ok = True)
    
with open(f'{save_path}/entity2id.pkl', 'wb') as fp:
    pickle.dump(entity2id, fp)

with open(f'{save_path}/entity2id.txt', 'w') as fp:
    fp.write(f'{len(entity2id)}\n')
    for i, h in enumerate(entity2id):
        escaped_h = repr(h)[1:-1]
        fp.write(f'{escaped_h} {i}\n')
        
with open(f'{save_path}/relation2id.txt', 'w') as fp:
    fp.write(f'{len(relation2id)}\n')
    for i, r in enumerate(relation2id):
        fp.write(f'{r} {i}\n')

with open(f'{save_path}/train2id.txt', 'w') as fp:
    fp.write(f'{len(train2id)}\n')
    for i, (h, r, t) in enumerate(train2id):
        fp.write(f'{h} {r} {t}\n')
with open(f'{save_path}/valid2id.txt', 'w') as fp:
    fp.write(f'{len(valid2id)}\n')
    for i, (h, r, t) in enumerate(valid2id):
        fp.write(f'{h} {r} {t}\n')
with open(f'{save_path}/test2id.txt', 'w') as fp:
    fp.write(f'{len(test2id)}\n')
    for i, (h, r, t) in enumerate(test2id):
        fp.write(f'{h} {r} {t}\n')


In [59]:
# save_path = "../data/3_openKE"
# if not os.path.isdir(f'{save_path}'):
#     os.makedirs(f'{save_path}', exits_ok = True)

# with open(f'{save_path}/entity2id.txt', 'w') as fp:
#     fp.write(f'{len(entity2id)}\n')
#     for i, h in enumerate(entity2id):
#         escaped_h = repr(h)[1:-1]  # 转义实体名称并去除前后的引号
#         fp.write(f'{escaped_h} {i}\n')
        
# with open(f'{save_path}/relation2id.txt', 'w') as fp:
#     fp.write(f'{len(relation2id)}\n')
#     for i, r in enumerate(relation2id):
#         fp.write(f'{r} {i}\n')

# with open(f'{save_path}/train2id.txt', 'w') as fp:
#     fp.write(f'{len(train2id)}\n')
#     for i, (h, r, t) in enumerate(train2id):
#         fp.write(f'{h} {r} {t}\n')
# with open(f'{save_path}/valid2id.txt', 'w') as fp:
#     fp.write(f'{len(valid2id)}\n')
#     for i, (h, r, t) in enumerate(valid2id):
#         fp.write(f'{h} {r} {t}\n')
# with open(f'{save_path}/test2id.txt', 'w') as fp:
#     fp.write(f'{len(test2id)}\n')
#     for i, (h, r, t) in enumerate(test2id):
#         fp.write(f'{h} {r} {t}\n')

- Handling the output triplets txt file
    - with the format: 706546, 197778, 12, 73 (source, destination, relation, label)

In [60]:
# all_triplets = list(all_triplets)
# print(len(all_triplets))

# output_triplets = []

# for t in tqdm(all_triplets):
#     output_triplets.append((entity2id[t[0]], entity2id[t[2]], relation2id[t[1]], label2id[t[3]]))
#     # break

# print(output_triplets[0])

# with open(f'{save_path}/triplets.txt', 'w') as fp:
#     fp.write(f'{len(all_triplets)}\n')
    
#     for i, (h, r, t, l) in tqdm(enumerate(output_triplets)):
#         fp.write(f'{h} {r} {t} {l}\n')
#         if i >= 4335640:
#             print(i)

In [61]:
# with open(f'{save_path}/triplets.txt', 'r') as fp:
#     for i in range (5):
#         line = fp.readline()
#         print(line)

In [62]:
with open(f'{save_path}/label2id.txt', 'w') as fp:
    fp.write(f'{len(label2id)}\n')
    for i, r in enumerate(label2id):
        print(f'{r} {i}')
        fp.write(f'{r} {i}\n')

T1059.001_702bfdd2-9947-4eda-b551-c3a1ea9a59a2_B 0
T1078.001_d0ca00832890baa1d42322cf70fcab1a_B 1
T1074.001_e6dfc7e89359ac6fa6de84b0e1d5762e_B 2
T1491_68235976-2404-42a8-9105-68230cfef562_B 3
T1016_14a21534-350f-4d83-9dd7-3c56b93a0c17_B 4
T1491_47d08617-5ce1-424a-8cc5-c9c978ce6bf9_I 5
T1074.001_4e97e699-93d7-4040-b5a3-2e906a58199e_I 6
T1040_6881a4589710d53f0c146e91db513f01_B 7
T1547.009_b6e5c895c6709fe289352ee23f062229_B 8
T1564.001_66a5fd5f244819181f074dd082a28905_B 9
T1047_f4b0b4129560ea66f9751275e82f6bab_B 10
T1112_257313a3c93e3bb7dfb60d6753b09e34_I 11
T1047_ac2764f7a67a9ce92b54e8e59b361838_B 12
T1518.001_33a24ff44719e6ac0614b58f8c9a7c72_B 13
T1204.002_522f3f35cd013e63830fa555495a0081_I 14
T1059.001_ccdb8caf-c69e-424b-b930-551969450c57_B 15
T1105_0856c235a1d26113d4f2d92e39c9a9f8_B 16
T1547_fe9eeee9a7b339089e5fa634b08522c1_I 17
T1574.001_63bbedafba2f541552ac3579e9e3737b_B 18
T1137.002_e2af3c3ab1b0f659c874b8af58c49759_B 19
T1105_e6715e61f5df646692c624b3499384c4_B 20
T1105_4f683658f161

## Make the 16900 graph in all_graph_data.jsonl

In [63]:
label2id

{'T1059.001_702bfdd2-9947-4eda-b551-c3a1ea9a59a2_B': 0,
 'T1078.001_d0ca00832890baa1d42322cf70fcab1a_B': 1,
 'T1074.001_e6dfc7e89359ac6fa6de84b0e1d5762e_B': 2,
 'T1491_68235976-2404-42a8-9105-68230cfef562_B': 3,
 'T1016_14a21534-350f-4d83-9dd7-3c56b93a0c17_B': 4,
 'T1491_47d08617-5ce1-424a-8cc5-c9c978ce6bf9_I': 5,
 'T1074.001_4e97e699-93d7-4040-b5a3-2e906a58199e_I': 6,
 'T1040_6881a4589710d53f0c146e91db513f01_B': 7,
 'T1547.009_b6e5c895c6709fe289352ee23f062229_B': 8,
 'T1564.001_66a5fd5f244819181f074dd082a28905_B': 9,
 'T1047_f4b0b4129560ea66f9751275e82f6bab_B': 10,
 'T1112_257313a3c93e3bb7dfb60d6753b09e34_I': 11,
 'T1047_ac2764f7a67a9ce92b54e8e59b361838_B': 12,
 'T1518.001_33a24ff44719e6ac0614b58f8c9a7c72_B': 13,
 'T1204.002_522f3f35cd013e63830fa555495a0081_I': 14,
 'T1059.001_ccdb8caf-c69e-424b-b930-551969450c57_B': 15,
 'T1105_0856c235a1d26113d4f2d92e39c9a9f8_B': 16,
 'T1547_fe9eeee9a7b339089e5fa634b08522c1_I': 17,
 'T1574.001_63bbedafba2f541552ac3579e9e3737b_B': 18,
 'T1137.002_e2a

In [64]:
def collect_resource(events, entity2id, relation2id, label2id):
    global nodes, edges, all_tuple, relations
    processed_data_lines = []

    for e in events:
        if e["relation"] not in relations:
            continue

        srcNode = entity2id.get(str(e["srcNode"][type2attr[e["srcNode"]["Type"]]]))
        dstNode = entity2id.get(str(e["dstNode"][type2attr[e["dstNode"]["Type"]]]) if e["dstNode"] != None else srcNode)
        relation = relation2id.get(e["relation"])
        label = label2id.get(e["label"])

        if srcNode is not None and dstNode is not None and relation is not None:
            processed_data_lines.append(f"{srcNode} {dstNode} {relation} {label}")

    # print(processed_data_lines)
    return processed_data_lines


In [65]:
def process_data(data_lines, current_file_path):
    # for edge_index
    node_to_index = {}
    
    nodes = set()
    edges = []
    edge_attrs = []
    labels = []
    
    # for line in tqdm(data_lines, desc="Processing lines"):
    for line in data_lines:
        try:
            parts = line.strip().split()
            try:
                source_node = int(parts[0])
            except ValueError:
                source_node = parts[0]
                print(f"Non-integer source node encountered: {source_node} in line: {line}, file: {current_file_path}")

            try:
                dest_node = int(parts[1])
            except ValueError:
                dest_node = parts[1]
                print(f"Non-integer dest node encountered: {dest_node} in line: {line}, file: {current_file_path}")


            edge_id = int(parts[2])
            label = int(parts[3])

            nodes.add(source_node)
            nodes.add(dest_node)
            edges.append([source_node, dest_node])
            edge_attrs.append(edge_id)
            labels.append(label)
            
        except Exception as e:
            print(f"Error processing line: {line}\nError: {e}")
            continue
    
    for index, node in enumerate(nodes):
        node_to_index[node] = index

    updated_edges = []
    for edge in edges:
        src, dst = edge
        if isinstance(src, int) and isinstance(dst, int):
            updated_edges.append([node_to_index[src], node_to_index[dst]])
    edges = list(zip(*updated_edges)) if updated_edges else ([], [])
    
    return {
        "labels": labels,
        "num_nodes": len(nodes),
        "node_feat": list(nodes),
        "edge_attr": edge_attrs,
        "edge_index": edges
    }

In [66]:
lines = 0
path = glob(f'../data/TrainingData/*/number_*/expanded_instance.json')
OUTPUT_PATH = '../data/all_graph_data.jsonl'

# would be 167 x 100 graphs
for p in tqdm(path):
    with open(p) as fp:
        events = json.load(fp)
    # print(p)
    
    # get the info from the original json files first
    data_lines = collect_resource(events, entity2id, relation2id, label2id)

    # process the data into graph for each json file
    processed_data = process_data(data_lines, p)
    # print(processed_data)
    
    # break
    
    with open(OUTPUT_PATH, 'a') as f:
        f.write(json.dumps(processed_data))
        f.write('\n') 
        lines += 1
        
        # json.dump(processed_data, f)
        # if lines == 2: break

print(f"Output {lines} lines at {OUTPUT_PATH}!!")

  0%|          | 0/16900 [00:00<?, ?it/s]

Output 16900 lines at ../data/all_graph_data.jsonl!!


# n-n

- YR version

In [67]:
lef = {}
rig = {}
rellef = {}
relrig = {}

triple = open(f'{save_path}/train2id.txt', "r")
valid = open(f'{save_path}/valid2id.txt', "r")
test = open(f'{save_path}/test2id.txt', "r")

tot = (int)(triple.readline())
for i in tqdm(range(tot)):
    content = triple.readline()
    h,t,r = content.strip().split()
    if not (h,r) in lef:
        lef[(h,r)] = []
    if not (r,t) in rig:
        rig[(r,t)] = []
    lef[(h,r)].append(t)
    rig[(r,t)].append(h)
    if not r in rellef:
        rellef[r] = {}
    if not r in relrig:
        relrig[r] = {}
    rellef[r][h] = 1
    relrig[r][t] = 1

tot = (int)(valid.readline())
for i in tqdm(range(tot)):
    content = valid.readline()
    h,t,r = content.strip().split()
    if not (h,r) in lef:
        lef[(h,r)] = []
    if not (r,t) in rig:
        rig[(r,t)] = []
    lef[(h,r)].append(t)
    rig[(r,t)].append(h)
    if not r in rellef:
        rellef[r] = {}
    if not r in relrig:
        relrig[r] = {}
    rellef[r][h] = 1
    relrig[r][t] = 1

tot = (int)(test.readline())
for i in tqdm(range(tot)):
    content = test.readline()
    h,t,r = content.strip().split()
    if not (h,r) in lef:
        lef[(h,r)] = []
    if not (r,t) in rig:
        rig[(r,t)] = []
    lef[(h,r)].append(t)
    rig[(r,t)].append(h)
    if not r in rellef:
        rellef[r] = {}
    if not r in relrig:
        relrig[r] = {}
    rellef[r][h] = 1
    relrig[r][t] = 1

test.close()
valid.close()
triple.close()

f = open(f'{save_path}/type_constrain.txt', "w")
f.write("%d\n"%(len(rellef)))
for i in tqdm(rellef):
    f.write("%s\t%d"%(i,len(rellef[i])))
    for j in rellef[i]:
        f.write("\t%s"%(j))
    f.write("\n")
    f.write("%s\t%d"%(i,len(relrig[i])))
    for j in relrig[i]:
        f.write("\t%s"%(j))
    f.write("\n")
f.close()

  0%|          | 0/3498383 [00:00<?, ?it/s]

  0%|          | 0/437298 [00:00<?, ?it/s]

  0%|          | 0/437298 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

- euni version

In [None]:
# This is the code of openke

from tqdm import tqdm

lef = {}
rig = {}
rellef = {}
relrig = {}

with open(f"{data_path}/train2id.txt", "r") as triple:
    tot = int(triple.readline())
    for i in tqdm(range(tot), desc="Processing train data"):
        content = triple.readline()
        h, t, r = content.strip().split()
        if not (h, r) in lef:
            lef[(h, r)] = []
        if not (r, t) in rig:
            rig[(r, t)] = []
        lef[(h, r)].append(t)
        rig[(r, t)].append(h)
        if not r in rellef:
            rellef[r] = {}
        if not r in relrig:
            relrig[r] = {}
        rellef[r][h] = 1
        relrig[r][t] = 1

# 处理验证数据
with open(f"{data_path}/valid2id.txt", "r") as valid:
    tot = int(valid.readline())
    for i in tqdm(range(tot), desc="Processing validation data"):
        content = valid.readline()
        h, t, r = content.strip().split()
        if not (h, r) in lef:
            lef[(h, r)] = []
        if not (r, t) in rig:
            rig[(r, t)] = []
        lef[(h, r)].append(t)
        rig[(r, t)].append(h)
        if not r in rellef:
            rellef[r] = {}
        if not r in relrig:
            relrig[r] = {}
        rellef[r][h] = 1
        relrig[r][t] = 1

# 处理测试数据
with open(f"{data_path}/test2id.txt", "r") as test:
    tot = int(test.readline())
    for i in tqdm(range(tot), desc="Processing test data"):
        content = test.readline()
        h, t, r = content.strip().split()
        if not (h, r) in lef:
            lef[(h, r)] = []
        if not (r, t) in rig:
            rig[(r, t)] = []
        lef[(h, r)].append(t)
        rig[(r, t)].append(h)
        if not r in rellef:
            rellef[r] = {}
        if not r in relrig:
            relrig[r] = {}
        rellef[r][h] = 1
        relrig[r][t] = 1

# 保存类型约束
with open(f"{data_path}/type_constrain.txt", "w") as f:
    f.write("%d\n" % (len(rellef)))
    for i in rellef:
        f.write("%s\t%d" % (i, len(rellef[i])))
        for j in rellef[i]:
            f.write("\t%s" % (j))
        f.write("\n")
        f.write("%s\t%d" % (i, len(relrig[i])))
        for j in relrig[i]:
            f.write("\t%s" % (j))
        f.write("\n")

# 清空并准备计算关系
rellef = {}
totlef = {}
relrig = {}
totrig = {}
for i in lef:
    if not i[1] in rellef:
        rellef[i[1]] = 0
        totlef[i[1]] = 0
    rellef[i[1]] += len(lef[i])
    totlef[i[1]] += 1.0

for i in rig:
    if not i[0] in relrig:
        relrig[i[0]] = 0
        totrig[i[0]] = 0
    relrig[i[0]] += len(rig[i])
    totrig[i[0]] += 1.0

s11 = 0
s1n = 0
sn1 = 0
snn = 0

# 计算关系类型
with open(f"{data_path}/test2id.txt", "r") as f:
    tot = int(f.readline())
    for i in tqdm(range(tot), desc="Calculating relation types"):
        content = f.readline()
        h, t, r = content.strip().split()
        rign = rellef[r] / totlef[r]
        lefn = relrig[r] / totrig[r]
        if (rign < 1.5 and lefn < 1.5):
            s11 += 1
        if (rign >= 1.5 and lefn < 1.5):
            s1n += 1
        if (rign < 1.5 and lefn >= 1.5):
            sn1 += 1
        if (rign >= 1.5 and lefn >= 1.5):
            snn += 1

# 分类测试数据
with open(f"{data_path}/test2id.txt", "r") as f, \
     open(f"{data_path}/1-1.txt", "w") as f11, \
     open(f"{data_path}/1-n.txt", "w") as f1n, \
     open(f"{data_path}/n-1.txt", "w") as fn1, \
     open(f"{data_path}/n-n.txt", "w") as fnn, \
     open(f"{data_path}/test2id_all.txt", "w") as fall:

    tot = int(f.readline())
    fall.write("%d\n" % tot)
    f11.write("%d\n" % s11)
    f1n.write("%d\n" % s1n)
    fn1.write("%d\n" % sn1)
    fnn.write("%d\n" % snn)

    for i in tqdm(range(tot), desc="Sorting test data"):
        content = f.readline()
        h, t, r = content.strip().split()
        rign = rellef[r] / totlef[r]
        lefn = relrig[r] / totrig[r]
        if (rign < 1.5 and lefn < 1.5):
            f11.write(content)
            fall.write("0\t" + content)
        if (rign >= 1.5 and lefn < 1.5):
            f1n.write(content)
            fall.write("1\t" + content)
        if (rign < 1.5 and lefn >= 1.5):
            fn1.write(content)
            fall.write("2\t" + content)
        if (rign >= 1.5 and lefn >= 1.5):
            fnn.write(content)
            fall.write("3\t" + content)

In [53]:
lef = {}
rig = {}
rellef = {}
relrig = {}

triple = open(f'{save_path}/train2id.txt', "r")
valid = open(f'{save_path}/valid2id.txt', "r")
test = open(f'{save_path}/test2id.txt', "r")

tot = (int)(triple.readline())
for i in tqdm(range(tot)):
    content = triple.readline()
    h,t,r = content.strip().split()
    if not (h,r) in lef:
        lef[(h,r)] = []
    if not (r,t) in rig:
        rig[(r,t)] = []
    lef[(h,r)].append(t)
    rig[(r,t)].append(h)
    if not r in rellef:
        rellef[r] = {}
    if not r in relrig:
        relrig[r] = {}
    rellef[r][h] = 1
    relrig[r][t] = 1

tot = (int)(valid.readline())
for i in tqdm(range(tot)):
    content = valid.readline()
    h,t,r = content.strip().split()
    if not (h,r) in lef:
        lef[(h,r)] = []
    if not (r,t) in rig:
        rig[(r,t)] = []
    lef[(h,r)].append(t)
    rig[(r,t)].append(h)
    if not r in rellef:
        rellef[r] = {}
    if not r in relrig:
        relrig[r] = {}
    rellef[r][h] = 1
    relrig[r][t] = 1

tot = (int)(test.readline())
for i in tqdm(range(tot)):
    content = test.readline()
    h,t,r = content.strip().split()
    if not (h,r) in lef:
        lef[(h,r)] = []
    if not (r,t) in rig:
        rig[(r,t)] = []
    lef[(h,r)].append(t)
    rig[(r,t)].append(h)
    if not r in rellef:
        rellef[r] = {}
    if not r in relrig:
        relrig[r] = {}
    rellef[r][h] = 1
    relrig[r][t] = 1

test.close()
valid.close()
triple.close()

f = open(f'{save_path}/type_constrain.txt', "w")
f.write("%d\n"%(len(rellef)))
for i in tqdm(rellef):
    f.write("%s\t%d"%(i,len(rellef[i])))
    for j in rellef[i]:
        f.write("\t%s"%(j))
    f.write("\n")
    f.write("%s\t%d"%(i,len(relrig[i])))
    for j in relrig[i]:
        f.write("\t%s"%(j))
    f.write("\n")
f.close()