In [12]:
import json
from langchain_community.graphs.age_graph import AGEGraph
import os
from langchain_core.documents import Document
from langchain_postgres import PGVector
from langchain_postgres.vectorstores import PGVector
import psycopg2
import getpass
import os
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
import pickle
import itertools


# Connection settings for the local Postgres + Apache AGE instance, shared by AGEGraph,
# psycopg2 and PGVector below. Each value can be overridden through the standard libpq
# environment variables so credentials need not be edited into the notebook; the
# fallbacks are the original local-development values.
database = {
    "database": os.getenv("PGDATABASE", "postgres"),
    "user": os.getenv("PGUSER", "postgres"),
    "password": os.getenv("PGPASSWORD", "password"),
    "host": os.getenv("PGHOST", "localhost"),
    "port": os.getenv("PGPORT", "5432"),
}

In [2]:
def _skip_schema_refresh(self):
    """Stand-in for ``AGEGraph.refresh_schema`` that does nothing and returns None."""
    return None


# Patched on the class itself, so it applies to every AGEGraph instance, and done
# before the graph object below is constructed.
# NOTE(review): presumably disabled to avoid slow or failing schema introspection on
# the "gnd" graph — confirm, and note any schema-dependent AGEGraph features.
AGEGraph.refresh_schema = _skip_schema_refresh

graph = AGEGraph(graph_name="gnd", conf=database)

In [3]:
# Raw psycopg2 connection built from the same `database` settings as the graph and
# vector store. `database` uses the key "database", psycopg2 expects "dbname"; the
# other keys map one-to-one.
# NOTE(review): no visible cell uses or closes this connection — remove it or close it when done.
connection = psycopg2.connect(
    dbname=database["database"],
    **{option: database[option] for option in ("user", "password", "host", "port")},
)

In [4]:
TIBKAT_DEV_DIR = "llms4subjects/shared-task-datasets/TIBKAT/tib-core-subjects/data/dev/"


def _normalise_quotes(text: str) -> str:
    """Replace every double quote with a single quote (applied to all extracted text fields)."""
    return text.replace("\"", "'")


def _parse_tibkat_record(nodes: list) -> Document:
    """Build one Document from the ``"@graph"`` node list of a TIBKAT JSON-LD record.

    The node carrying ``"@type"`` supplies the abstract (page_content), the document id
    (``"@id"``) and the title; every other node contributes its ``"sameAs"`` value(s) to
    ``metadata["subjects"]``. A record with no ``"@type"`` node still yields an (empty)
    Document, so the number and order of documents match one-per-file.

    Raises KeyError if the ``"@type"`` node lacks ``"abstract"`` or ``"@id"``.
    """
    document = Document(page_content="", metadata={})
    subjects: list[str] = []
    for node in nodes:
        if "@type" in node:
            document = Document(
                page_content=_normalise_quotes(str(node["abstract"])),
                metadata={"subjects": []},
                id=node["@id"],
            )
            # Title is either a plain string or a list whose first entry is used.
            title = node.get("title")
            if isinstance(title, list):
                title = title[0] if title else None
            if isinstance(title, str):
                document.metadata["title"] = _normalise_quotes(title)
        else:
            same_as = node.get("sameAs")
            if isinstance(same_as, str):
                subjects.append(_normalise_quotes(same_as))
            elif isinstance(same_as, list):
                subjects.extend(_normalise_quotes(name) for name in same_as)
    document.metadata["subjects"] = subjects
    return document


documents: list[Document] = []
# NOTE: os.walk order is filesystem-dependent, and later cells pair `documents` with
# `document_embeddings` by position — the order here must match the one used for embedding.
for dir_path, _, file_names in os.walk(TIBKAT_DEV_DIR):
    for file_name in file_names:
        # Explicit UTF-8: the records contain non-ASCII (German) text and the platform
        # default encoding is not guaranteed to be UTF-8.
        with open(os.path.join(dir_path, file_name), encoding="utf-8") as record_file:
            documents.append(_parse_tibkat_record(json.load(record_file)["@graph"]))

len(documents)

6980

In [5]:
from typing import List, Optional


if not os.getenv("OPENAI_API_KEY"):
    # Prompt interactively so the key never has to be written into the notebook.
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key: ")

embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
)

EMBEDDINGS_CACHE = "tibkat-core-dev.pkl"


def _load_cached_embeddings(path: str, doc_ids: List[Optional[str]]) -> Optional[List[List[float]]]:
    """Return cached embeddings aligned with ``doc_ids``, or None if the cache is missing or stale.

    Two cache formats are accepted:

    * ``{"ids": [...], "embeddings": [...]}`` (written by this cell) — realigned by document
      id, so a different document order cannot pair documents with the wrong vectors;
    * a bare list of vectors (older format) — only length-checked, as its order cannot be verified.
    """
    if not os.path.exists(path):
        return None
    # pickle.load can execute arbitrary code: only ever point this at a cache this notebook wrote.
    with open(path, "rb") as cache_file:
        cached = pickle.load(cache_file)
    if isinstance(cached, dict):
        cached_ids = cached.get("ids", [])
        cached_vectors = cached.get("embeddings", [])
        if cached_ids == doc_ids and len(cached_vectors) == len(doc_ids):
            return cached_vectors
        by_id = dict(zip(cached_ids, cached_vectors))
        # Realign only when ids are unique on both sides, so each document maps to exactly one vector.
        ids_unique = len(set(doc_ids)) == len(doc_ids) and len(by_id) == len(cached_ids)
        if ids_unique and all(doc_id in by_id for doc_id in doc_ids):
            return [by_id[doc_id] for doc_id in doc_ids]
        return None
    if isinstance(cached, list) and len(cached) == len(doc_ids):
        return cached
    return None


document_ids = [document.id for document in documents]
cached_embeddings = _load_cached_embeddings(EMBEDDINGS_CACHE, document_ids)
if cached_embeddings is None:
    # Cache missing or stale: embed every abstract and rewrite the cache together with the ids.
    document_embeddings: List[List[float]] = embeddings.embed_documents([d.page_content for d in documents])
    with open(EMBEDDINGS_CACHE, "wb") as cache_file:
        pickle.dump({"ids": document_ids, "embeddings": document_embeddings}, cache_file)
else:
    document_embeddings = cached_embeddings

assert len(document_embeddings) == len(documents), "embeddings are out of sync with documents"

In [13]:
from urllib.parse import quote

# SQLAlchemy URL for the same database. User and password are percent-encoded so
# characters such as '@', ':', '/' or spaces in them cannot corrupt the URL
# (SQLAlchemy decodes them again when parsing).
PG_CONNECTION_URL = (
    f"postgresql+psycopg2://{quote(database['user'], safe='')}:{quote(database['password'], safe='')}"
    f"@{database['host']}:{database['port']}/{database['database']}"
)

# Vector store over the "tibkat" collection; this notebook only queries it (it is
# never populated here). use_jsonb stores metadata as JSONB.
vector_store = PGVector(
    embeddings=embeddings,
    collection_name="tibkat",
    connection=PG_CONNECTION_URL,
    use_jsonb=True,
)

llm = ChatOpenAI(
    model="gpt-3.5-turbo",
)

In [30]:
import pprint

# Stop after this many documents whose LLM output parsed successfully (1 reproduces the
# original single-document debug run); set to None to run over the whole dev set.
MAX_EVALUATED_DOCUMENTS = 1

# Output contract for the LLM: an object with a single "subjects" array of strings.
json_schema = {
    "title": "subjects",
    "description": "The subjects that are related to the article",
    "type": "object",
    "properties": {
        "subjects": {
            "type": "array",
            "description": "The subjects that are related to the article",
            "items": {
                "type": "string",
            },
        },
    },
    "required": ["subjects"],
}

# `with_structured_output` returns a NEW runnable; the original discarded it and invoked the
# bare model, so the schema was never applied. `function_calling` is supported by
# gpt-3.5-turbo, and `include_raw=True` reports parse failures in the result instead of raising.
structured_llm = llm.with_structured_output(
    json_schema, method="function_calling", include_raw=True
)


def _query_subject_nodes(subject_names):
    """Return the Subject nodes linked (either direction) to an AlternativeName in ``subject_names``."""
    return graph.query(f"""
    MATCH (s:Subject)-[:ALTERNATIVE_NAME]-(an:AlternativeName)
    WHERE an.name IN {json.dumps(subject_names)}
    RETURN s
    """)


def _candidate_subject_names(embedding, k=10):
    """Collect candidate subject names for the document with ``embedding``.

    Takes the subjects of the ``k`` nearest stored documents, then returns the alternative
    names of every Subject that has a RELATED edge pointing to one of those subjects
    (duplicates removed, first occurrence kept).
    """
    neighbours = vector_store.similarity_search_by_vector(embedding, k=k)
    # NOTE(review): if the "tibkat" collection also holds these dev documents, a document's
    # nearest neighbour is itself and its gold subjects leak into the candidates — confirm.
    # sorted() keeps the query/prompt reproducible; list(set(...)) ordering varies between runs.
    neighbour_subjects = sorted(set(itertools.chain.from_iterable(
        neighbour.metadata.get("subjects", []) for neighbour in neighbours
    )))
    # NOTE(review): this returns names of subjects RELATED *to* the neighbours' subjects, not the
    # neighbours' subjects themselves — confirm that is intended.
    records = graph.query(f"""
    MATCH (s:Subject)-[:ALTERNATIVE_NAME]->(an:AlternativeName), (s2:Subject)-[:RELATED]->(s), (s2:Subject)-[:ALTERNATIVE_NAME]->(an2)
    WHERE an.name IN {json.dumps(neighbour_subjects)}
    RETURN an2
    """)
    # One row per matching path, so the same name can appear many times.
    return list(dict.fromkeys(record["an2"]["name"] for record in records))


evaluated = 0
for document, embedding in zip(documents, document_embeddings):
    title = document.metadata.get("title", "")
    filtered_subject_names = _candidate_subject_names(embedding)

    llm_query = f"""
    You are an AI that has been trained on a large corpus of scientific articles. You have been asked to provide a list of subjects that are related to the following article: `{title}`. The article is about `{document.page_content}`. The subjects that can be related to this article are: {filtered_subject_names}. Do not include any subjects that are not related to the article. Do not generate any new subjects. Use the subjects that are already in the database. Print the list of subjects as a JSON object. Print the list of subjects in the same way that they are presented. DO NOT translate them. Do not print anything else, just the list of subjects as a JSON object. Do not capitalize `subjects`.

    The subjects that are related to this article are:
    """

    result = structured_llm.invoke(llm_query)
    parsed = result["parsed"]
    if result["parsing_error"] is not None or not isinstance(parsed, dict) or "subjects" not in parsed:
        # Show the raw model message for inspection and move on to the next document.
        print(result["raw"])
        continue
    filtered_subjects = parsed["subjects"]

    subjects_codes = _query_subject_nodes(filtered_subjects)
    real_subjects = _query_subject_nodes(document.metadata["subjects"])

    pprint.pprint({
        "title": title,
        "real_subjects": real_subjects,
        "predicted_subjects": subjects_codes,
    })

    evaluated += 1
    if MAX_EVALUATED_DOCUMENTS is not None and evaluated >= MAX_EVALUATED_DOCUMENTS:
        break

{'predicted_subjects': [{'s': {'classification_name': 'Methoden und Techniken '
                                                      'der empirischen '
                                                      'Sozialforschung, '
                                                      'Statistik in den '
                                                      'Sozialwissenschaften, '
                                                      'Mathematische Statistik',
                               'code': 'gnd:4153917-5'}},
                        {'s': {'classification_name': 'Erkenntnistheorie, '
                                                      'Logik',
                               'code': 'gnd:4280370-6'}},
                        {'s': {'classification_name': 'Erkenntnistheorie, '
                                                      'Logik',
                               'code': 'gnd:4426693-5'}},
                        {'s': {'classification_name': 'Stochastik, Operations '
       

In [None]:
# Ad-hoc probe: Subject nodes with an outgoing ALTERNATIVE_NAME edge to an
# AlternativeName named exactly `probe_name`. The saved run returned [] for 'Methodologie'.
# The name is embedded via json.dumps, the same way the other queries in this notebook quote values.
probe_name = "Methodologie"
query = f"""
    MATCH (s:Subject)-[:ALTERNATIVE_NAME]->(an:AlternativeName)
    WHERE an.name = {json.dumps(probe_name)}
    RETURN s
"""

graph.query(query)


[]