In [1]:
# Sample input for the pipeline: a short love letter, one list entry per line.
test_text = "\n".join([
    "To my Chris, I have been thinking",
    "about how I could possibly tell you",
    "how much you mean to me. I remember",
    "when I first started to fall in",
    "love with you like it was last",
    "night. Lying naked beside you in",
    "that tiny apartment, it suddenly",
    "hit me that I was part of this",
    "whole larger thing, just like our",
    "parents, and our parents' parents.",
    "Before that I was just living my",
    "life like I knew everything, and",
    "suddenly this bright light hit me",
    "and woke me up. That light was you.",
])

# Pipeline: text -> chunks -> sentences -> concept triplets (source, target, relation)

In [2]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema import SystemMessage
import spacy
import json
from const import *
from prompts import *
from utils import *
import uuid

In [3]:
# spaCy pipeline used below to split each chunk into sentences (doc.sents).
SPACY_MODEL_NAME = "en_core_web_md"
nlp = spacy.load(SPACY_MODEL_NAME)

In [4]:
def chunk_text(text=None, title="Poem", chunk_size=250, chunk_overlap=20):
    """Split a text into overlapping chunks, preferring breaks after periods.

    Args:
        text: Text to split. Defaults to the notebook-level ``test_text``
            (looked up at call time).
        title: Stored as ``metadata["title"]`` on every chunk.
        chunk_size: Target maximum chunk length in characters (measured
            with ``len``).
        chunk_overlap: Number of characters consecutive chunks may share.

    Returns:
        A list of ``(page_content, metadata)`` tuples, one per chunk.
    """
    if text is None:
        text = test_text
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
        # Only split at ". " / ".\n" so chunk boundaries fall between sentences.
        separators=['. ', '.\n'],
    )
    doc = Document(page_content=text, metadata={"title": title})
    pages = text_splitter.split_documents([doc])
    return [(page.page_content, page.metadata) for page in pages]

In [5]:
def extract_triplet(text, context):
    """Ask the chat model for concept triplets found in one sentence.

    The prompt is built from ``system_message`` plus the ``task_message``
    template (filled with ``context`` and ``target``) and sent to
    gpt-3.5-turbo-0125 at temperature 0. The reply is then repaired into a
    JSON array before parsing.

    Args:
        text: The sentence to analyse (passed to the prompt as ``target``).
        context: The chunk the sentence came from (passed as ``context``).

    Returns:
        A list of dicts parsed from the reply. In each dict the
        ``relation`` key is renamed to ``edge`` and a new ``id`` (UUID4
        string) is added. Items that are not objects or lack ``relation``
        are skipped, with a message. Returns ``[]`` if the reply is empty
        or is still not valid JSON after repair.
    """
    chat = ChatOpenAI(
        api_key=OPENAI_API_KEY,
        model_name="gpt-3.5-turbo-0125",
        temperature=0,
        max_tokens=1000,
    )
    system_message_prompt = SystemMessage(content=system_message)
    human_message_prompt = HumanMessagePromptTemplate.from_template(task_message)
    chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
    messages = chat_prompt.format_prompt(
        context=context,
        target=text,
    ).to_messages()
    reply = chat.invoke(messages)

    raw = str(reply.content).strip()
    # Drop a Markdown code fence (```json ... ```) if the model added one.
    if raw.startswith("```"):
        raw = raw.split("\n", 1)[1] if "\n" in raw else ""
        raw = raw.strip()
        if raw.endswith("```"):
            raw = raw[:-3].strip()
    if not raw:
        # An empty reply contains no triplets (and has no first character to check).
        return []

    # The reply should be a JSON array: add a missing opening bracket, let
    # add_missing_commas repair separators, then collapse any run of commas.
    if not raw.startswith('['):
        raw = '[' + raw
    raw = add_missing_commas(raw)
    while ',,' in raw:
        raw = raw.replace(',,', ',')
    # Close the array if the model left it open (dropping a dangling comma).
    raw = raw.rstrip()
    if not raw.endswith(']'):
        raw = raw.rstrip(',').rstrip() + ']'

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # Best effort: show the unparseable reply and contribute no triplets.
        print('-----')
        print(raw)
        return []

    triplets = []
    for item in parsed:
        if not isinstance(item, dict) or 'relation' not in item:
            print(f'skipping malformed triplet: {item!r}')
            continue
        item['edge'] = item.pop('relation')
        item['id'] = str(uuid.uuid4())
        triplets.append(item)
    return triplets

In [9]:
def llm_extraction(csv_path='poem.csv'):
    """Extract concept triplets from ``chunk_text()`` and save them as a table.

    Each chunk is split into sentences with spaCy (``nlp``). Every sentence
    containing at least one letter or digit is sent to ``extract_triplet``,
    with its whole chunk as context. The collected triplets are passed to
    ``get_dataframe_graph``, and the resulting DataFrame is printed and
    written to ``csv_path``.

    Args:
        csv_path: Output CSV path. Defaults to ``'poem.csv'``, the file the
            plotting cell reads back.
    """
    triplets = []

    for chunk, _metadata in chunk_text():
        for sent in nlp(chunk).sents:
            sentence = sent.text.strip()
            # Whitespace/punctuation-only spans (e.g. a lone "\n" or ".")
            # carry no concepts, so don't spend an LLM call on them.
            if not any(ch.isalnum() for ch in sentence):
                continue
            # Terminate the sentence, keeping an existing "?"/"!" and never
            # placing the period after trailing whitespace.
            if not sentence.endswith(('.', '!', '?')):
                sentence += '.'
            triplets.extend(extract_triplet(sentence, chunk))

    # get_dataframe_graph returns (df, G); only the table is used here.
    df, _graph = get_dataframe_graph(triplets)
    print(df)
    df.to_csv(csv_path)

    

In [10]:
# Run the whole pipeline: an LLM request per sentence, then print the
# triplet table and write it to poem.csv (read back by the plotting cell).
llm_extraction()

          source sourceConceptType              target targetConceptType  \
0              I            entity               Chris            entity   
1              I            entity            to think             event   
2              I            entity                 you            entity   
3       to think             event               Chris            entity   
4       to think             event             meaning  abstract concept   
5        meaning  abstract concept                 you            entity   
6              I            entity            remember             event   
7       remember             event        fall in love             event   
8   fall in love             event                 you            entity   
9   fall in love             event          last night             event   
10             I            entity                 you            entity   
11           you            entity              beside             event   
12          

In [12]:
from pyvis.network import Network
import networkx as nx
import pandas as pd

# Reload the triplets saved by llm_extraction(). Read as text and drop rows
# missing a source or target: pyvis node ids must be str or int, so a NaN
# id (an empty cell in the CSV) would fail.
df = (
    pd.read_csv('poem.csv', dtype=str)[["source", "target", "edge"]]
    .dropna(subset=["source", "target"])
    .fillna({"edge": ""})
)
# networkx view of the same edges (not used for the plot below).
G = nx.from_pandas_edgelist(df, 'source', 'target', edge_attr=True, create_using=nx.MultiDiGraph())

# Plot the network. Edges are directed (source -> target), matching the
# MultiDiGraph above; add_node ignores ids it has already seen.
net = Network(height="750px", width="100%", directed=True)
for source, target, edge in df.itertuples(index=False, name=None):
    net.add_node(source, label=source)
    net.add_node(target, label=target)
    net.add_edge(source, target, label=edge)

net.show("poem.html", notebook=False)

poem.html
