In [20]:
import collections
import json
import os
import re
import openai
import pdfplumber
import ahocorasick

dataDir = "../data/"
dataName = "Deep Learning.pdf"

# 设置 api


In [57]:
# openai.api_base = "https://api.chatanywhere.com.cn/"
openai.api_base = 'https://api.chatanywhere.cn/'
openai.api_key = "sk-D1u13WweY1LhWLqv95Ml7e3y8f8ToSfsTkGnlgvSQLqZJptC"

# 一，读取数据


## 生成目录架构

- 生成**章节名称**与**章节序号**的对应：`name_dict`
- 以及**章节序号**与**章节页码**范围的对应：`index_dict`


In [3]:
with pdfplumber.open(dataDir + dataName) as f:
    # 目录架构生成
    c, p, n = [], [], []
    for i in range(7):
        page = f.pages[i]
        text = page.extract_text()
        text_split = text.split("\n")
        for i in text_split:
            if bool(re.match("[0-9]+\.[0-9]+", i.split(" ")[0])):
                c.append(i.split(" ")[0])
                p.append(i.split(" ")[-1])
            if bool(re.match("[0-9]+", i.split(" ")[0])):
                for j in i.split(" "):
                    if bool(re.match("[A-Za-z]+", j)):
                        n.append((i.split(" ")[0], j))

### name_dict


In [4]:
name_dict = {}
for i, j in n:
    if i in name_dict:
        name_dict[i] = name_dict.get(i, "") + " " + j
    else:
        name_dict[i] = name_dict.get(i, "") + j

### index_dict


In [5]:
p_range = list(zip(p, p[1:]))
p_range.append((720, 800))
c_p_range = list(zip(c, p_range))
index_dict = collections.defaultdict(list)
for k, v in c_p_range:
    index_dict[k.split(".")[0]].append((k, v))

## 生成内容表

- 段落内容表：`content_dict`


In [6]:
with pdfplumber.open(dataDir + dataName) as f:
    content_dict = collections.defaultdict(list)

    for k, v in index_dict.items():
        for i in v:
            page_range = i[-1]

            for j in range(int(page_range[0]), int(page_range[1])):
                page = f.pages[j]

                text = page.extract_text().replace("\n", " ")

                content_dict[i[0]].append(text)

# 二，Build Relations

## 目录 + 前置

In [7]:
def catalogue_relations(dataName, name_dict=None, relation_type=["目录", "前置"]):
    c_relations = []
    p_relations = []
    if not name_dict:
        name_dict = {}
    for k, v in name_dict.items():
        if bool(re.match("[0-9]+\.[0-9]+", k)):
            p_relations.append(
                [
                    k.split(".")[0] + " " + name_dict[k.split(".")[0]],
                    relation_type[1],
                    k + " " + v,
                ]
            )
        else:
            c_relations.append(
                [
                    dataName,
                    relation_type[0],
                    k + " " + v,
                ]
            )
    return p_relations, c_relations

In [8]:
p_relations, c_relations = catalogue_relations(
    dataName=dataName.split(".")[0], name_dict=name_dict
)

## 包含 + 段落共现

### AC自动机

In [9]:
def build(patterns):
    trie = ahocorasick.Automaton()
    for index, word in enumerate(patterns):
        trie.add_word(word, (index, word))
    trie.make_automaton()
    return trie

In [27]:
data = []
for i in range(1, 21):
    with open(dataDir + "relations/" + "kb_chapter_" + str(i) + ".json", "r") as f:
        data.append([i, json.load(f)])

In [172]:
re_set = set()
for i in data:
    for j in i[1]:
        if not bool(re.match('^([0-9]+|[A-Za-z])$', j[0])) and not bool(re.match('^([0-9]+|[A-Za-z])$', j[2])) and j[0] != j[2]:
            re_set.add(tuple(j))

In [28]:
raw_entity_set = set()
relation_set = set()
for i in data:
    for relation in i[1]:
        if relation[0] not in raw_entity_set:
            raw_entity_set.add(relation[0])
        if relation[2] not in raw_entity_set:
            raw_entity_set.add(relation[2])
        if relation[1] not in relation_set:
            relation_set.add(relation[1])

In [29]:
entity_set = set()
for i in raw_entity_set:
    if not bool(re.match('^([0-9]+|[A-Za-z])$', i)):
        entity_set.add(i)

In [30]:
def include_co_presence(entity_set):
    patterns = list(entity_set)
    trie = build(patterns)
    include_relations = []
    co_presence_relations = set()
    for chapter, sections in index_dict.items():
        for section in sections:
            for content in content_dict[section[0]]:
                word_set = set(word[1][1] for word in trie.iter(content))
                for head in word_set:
                    if not bool(re.match("^(\d+|[A-Za-z])$", head)):
                        include_relations.append(
                            [section[0] + " " + name_dict[section[0]], "include", head]
                        )
                        for tail in word_set:
                            if head != tail and not bool(
                                re.match("^(\d+|[A-Za-z])$", tail)
                            ):
                                if (
                                    tail,
                                    "co_presence",
                                    head,
                                ) not in co_presence_relations:
                                    co_presence_relations.add(
                                        (head, "co_presence", tail)
                                    )
    return include_relations, co_presence_relations

In [31]:
include_relations, co_presence_relations = include_co_presence(entity_set)

## 结果展示

In [32]:
p_relations[0]

['1 Introduction', '前置', '1.1 Who Should Read This Book?']

In [33]:
c_relations[0]

['Deep Learning', '目录', '1 Introduction']

In [34]:
len(co_presence_relations)

1533

In [35]:
include_relations[0]

['1.1 Who Should Read This Book?', 'include', 'Feedforward']

In [36]:
structure_relations = (
    p_relations + c_relations + list(co_presence_relations) + include_relations
)

In [37]:
len(structure_relations)

3112

In [38]:
with open(os.path.join(dataDir + "/relations", f"structure_relations.json"), "w") as f:
    json.dump(structure_relations, f, indent=4)

# 测试

# 一些其他尝试

N - Gram


In [39]:
# 暂时没有探索结果
# all_text = re.sub('[^A-Za-z0-9\.]+', ' ', text).lower().split(' ')
# ng1 = collections.defaultdict(int)
# ng2 = collections.defaultdict(int)
# ng3 = collections.defaultdict(int)
# ng4 = collections.defaultdict(int)
# for i, j in enumerate(all_text):
#     ng1[j] += 1
#     if i > 0: ng2[(all_text[i-1], j)] += 1
#     if i > 1: ng3[(all_text[i-2], all_text[i-1], j)] += 1
#     if i > 2: ng4[(all_text[i-3], all_text[i-2], all_text[i-1], j)] += 1

# def sort_feq(dic):
#     return sorted([(k, v) for k, v in dic.items()], key=lambda x: x[1], reverse=True)

### 定义 Chat 类

In [58]:
def total_counts(response):
    tokens_nums = int(response["usage"]["total_tokens"])
    price = 0.002 / 1000
    cost = "{:.5f}".format(price * tokens_nums * 7.5)
    print(f"tokens: {tokens_nums}, cost: {cost}")

    return float(cost)

In [65]:
class Chat:
    def __init__(self, conversation_list=[]):
        self.conversation_list = conversation_list
        self.costs_list = []

    def show_conversation(self, msg_list):
        for msg in msg_list[-2:]:
            if msg["role"] == "user":
                pass
            else:
                message = msg["content"]
                print(f"\U0001f47D: {message}\n")
            print()

    def ask(self, prompt):
        self.conversation_list.append({"role": "user", "content": prompt})
        openai.api_key = "sk-D1u13WweY1LhWLqv95Ml7e3y8f8ToSfsTkGnlgvSQLqZJptC"
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo", messages=self.conversation_list
        )
        answer = response.choices[0].message["content"]

        self.conversation_list.append({"role": "assistant", "content": answer})
        self.show_conversation(self.conversation_list)

        cost = total_counts(response)
        self.costs_list.append(cost)
        return answer
        print()

In [155]:
conversation_list = [
    {
        "role": "system",
        "content": "你是一个深度学习、机器学习、数学、计算机科学领域的实体标注专员，给定字符串列表，请依次找出其中包含的深度学习、机器学习、数学、计算机科学领域的实体，并返回该实体在原字符串中的表述。除结果外，不要返回任何其他内容。如输入'MLP asfasdfasdf'，返回'MLP'",
    }
]

bot = Chat(conversation_list)

In [156]:
text = str(tuple(list(entity_set)[0:10]))

In [157]:
text

"('Hyperparameter', 'Family', 'Computer scientist', 'Weak supervision', 'Dream', 'Nonparametric statistics', 'Journal of the Royal Statistical Society', 'Rodent', 'Maximum likelihood estimation', 'Norm (mathematics)')"

In [159]:
answer = bot.ask('Hyperparameter')

bot.conversation_list


👽: Hyperparameter


tokens: 197, cost: 0.00295


[{'role': 'system',
  'content': "你是一个深度学习、机器学习、数学、计算机科学领域的实体标注专员，给定字符串列表，请依次找出其中包含的深度学习、机器学习、数学、计算机科学领域的实体，并返回该实体在原字符串中的表述。除结果外，不要返回任何其他内容。如输入'MLP asfasdfasdf'，返回'MLP'"},
 {'role': 'user',
  'content': "('Hyperparameter', 'Family', 'Computer scientist', 'Weak supervision', 'Dream', 'Nonparametric statistics', 'Journal of the Royal Statistical Society', 'Rodent', 'Maximum likelihood estimation', 'Norm (mathematics)')"},
 {'role': 'assistant',
  'content': 'Maximum likelihood estimation\nNorm (mathematics)'},
 {'role': 'user', 'content': 'Hyperparameter'},
 {'role': 'assistant', 'content': 'Hyperparameter'}]

In [None]:
res_tiny = [i.strip() for i in re.sub('"|}|{||\[|\]', "", answer).split(",")]

In [102]:
list(entity_set)[:100]

['feedforward neural networks',
 'SETTLEMENTS',
 'CPUs',
 'suppressing variance',
 'REINFORCE TAMADRAalgorithm',
 'Gradient descent',
 'reinforcement learning',
 'Laplacian pyramid',
 'tree structure',
 'sparse connectivity',
 'Machine能',
 'SLOW feature analysis',
 'stochastic autoencoder',
 'IEEE Transactions on Automatic Control',
 'Deepbelief network',
 'encoding schemes',
 'functional',
 'Bernoulli',
 'Bayesian network',
 'summation',
 'GENERATIVE MODELS',
 'minimally large',
 'maximize',
 'back-propagation algorithm',
 'physical energy',
 'annealed importance sampling',
 '2.4 Stochastic Encoders and Decoders',
 'IPv9',
 'deep autoencoder',
 'Jarzynski',
 'minimization problem',
 '6.2 Sigmoid Units for Bernoulli Output Distributions',
 'encoder',
 'embodying',
 'Restricted Boltzmann Machine',
 'Atkins Medal',
 'minimally posterior probability',
 'hidden Markov model',
 'Memory networks',
 'parametrize',
 'kernels',
 'tree',
 'precision',
 'RECurrent neural network',
 'probability',