In [12]:
import json
from langchain_community.graphs.age_graph import AGEGraph
import os
from langchain_core.documents import Document
from langchain_postgres import PGVector
from langchain_postgres.vectorstores import PGVector
import psycopg2
import getpass
import os
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
import pickle
import itertools


# Connection settings shared by the AGE graph, psycopg2 and PGVector cells.
# Each value is taken from the standard libpq environment variable when it is
# set and otherwise falls back to the local development default, so real
# credentials never have to be written into the notebook.
database = {
    "database": os.environ.get("PGDATABASE", "postgres"),
    "user": os.environ.get("PGUSER", "postgres"),
    "password": os.environ.get("PGPASSWORD", "password"),
    "host": os.environ.get("PGHOST", "localhost"),
    "port": os.environ.get("PGPORT", "5432"),
}

In [13]:
def _skip_schema_refresh(self):
    """No-op stand-in for ``AGEGraph.refresh_schema``; always returns None."""
    return None


# Patch the class before instantiating it, so the "gnd" graph object is built
# with the no-op in place of the library's own schema refresh.
# NOTE(review): presumably done to avoid a slow or failing schema
# introspection during construction — confirm against the AGEGraph source.
AGEGraph.refresh_schema = _skip_schema_refresh

graph = AGEGraph(conf=database, graph_name="gnd")

In [14]:
# Raw psycopg2 connection built from the shared `database` settings.
# psycopg2 calls the database-name argument `dbname`, hence the key rename.
_connect_kwargs = {
    "dbname": database["database"],
    "user": database["user"],
    "password": database["password"],
    "host": database["host"],
    "port": database["port"],
}
connection = psycopg2.connect(**_connect_kwargs)

In [15]:
DEV_DATA_DIR = "llms4subjects/shared-task-datasets/TIBKAT/tib-core-subjects/data/dev/"


def _clean(text: str) -> str:
    """Replace double quotes with single quotes.

    Later cells embed subject names inside double-quoted Cypher string
    literals, so double quotes are normalised away at load time.
    """
    return text.replace("\"", "'")


def _load_record(path: str) -> Document:
    """Build one ``Document`` from a single JSON-LD file.

    Entries of the file's ``@graph`` list that carry an ``@type`` key are
    treated as the article itself (``abstract`` -> page_content, ``title`` ->
    metadata, ``@id`` -> document id). Every other entry is treated as a
    subject record whose ``sameAs`` value(s) are collected into
    ``metadata["subjects"]``.

    Raises:
        KeyError: if an article entry lacks ``abstract``/``title``/``@id`` or
            a subject entry lacks ``sameAs``.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(path, encoding="utf-8") as handle:
        graph_entries = json.load(handle)["@graph"]

    subjects = []
    # Placeholder that is kept when no entry in the file carries "@type".
    document = Document(page_content="", metadata={})
    for entry in graph_entries:
        if "@type" in entry:
            document = Document(
                page_content=_clean(str(entry["abstract"])),
                metadata={"subjects": []},
                id=entry["@id"],
            )
            title = entry["title"]
            # A title may be a single string or a list; only the first list
            # item is kept. Any other type leaves the title unset.
            if isinstance(title, str):
                document.metadata["title"] = _clean(title)
            elif isinstance(title, list):
                document.metadata["title"] = _clean(title[0])
        else:
            same_as = entry["sameAs"]
            if isinstance(same_as, str):
                subjects.append(_clean(same_as))
            elif isinstance(same_as, list):
                subjects.extend(_clean(name) for name in same_as)
    document.metadata["subjects"] = subjects
    return document


documents: list[Document] = []

# NOTE: os.walk yields entries in filesystem order. Anything that pairs
# `documents` positionally with cached data relies on that order being stable.
for directory, _, filenames in os.walk(DEV_DATA_DIR):
    for filename in filenames:
        documents.append(_load_record(os.path.join(directory, filename)))

len(documents)

6980

In [16]:
from typing import List


# Ask for the key interactively only when the environment does not provide it.
if not os.getenv("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key: ")

embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
)

# On-disk cache of the embeddings, one vector per entry of `documents`, in the same order.
EMBEDDING_CACHE = "tibkat-core-dev.pkl"

if not os.path.exists(EMBEDDING_CACHE):
    document_embeddings: List[List[float]] = embeddings.embed_documents(
        [d.page_content for d in documents]
    )
    # Context manager guarantees the cache is flushed and closed.
    with open(EMBEDDING_CACHE, "wb") as cache_file:
        pickle.dump(document_embeddings, cache_file)
else:
    # SECURITY: unpickling can execute arbitrary code. Only load a cache file
    # that this notebook wrote itself.
    with open(EMBEDDING_CACHE, "rb") as cache_file:
        document_embeddings: List[List[float]] = pickle.load(cache_file)

# The embeddings are later zipped with `documents` by position, so a stale
# cache would silently mislabel or truncate results. Fail loudly instead.
if len(document_embeddings) != len(documents):
    raise ValueError(
        f"{EMBEDDING_CACHE} holds {len(document_embeddings)} embeddings but "
        f"{len(documents)} documents were loaded; delete the cache to rebuild it."
    )

In [6]:
# Connection URL for the psycopg2 driver, assembled from the shared `database` settings.
pg_connection_url = f"postgresql+psycopg2://{database['user']}:{database['password']}@{database['host']}:{database['port']}/{database['database']}"

# Vector store over the "tibkat" collection, using the embedding model configured above.
vector_store = PGVector(
    collection_name="tibkat",
    connection=pg_connection_url,
    embeddings=embeddings,
    use_jsonb=True,
)

# Chat model used to narrow candidate subjects down to the relevant ones.
llm = ChatOpenAI(model="gpt-4o")

# json_schema = {
#         "title": "subjects",
#         "description": "The subjects that are related to the article",
#         "type": "object",
#         "properties": {
#             "subjects": {
#                 "type": "array",
#                 "description": "The subjects that are related to the article",
#                 "items": {
#                     "type": "string",
#                 },
#             },
#         },
#         "required": ["subjects"],
#     }

# llm.with_structured_output(json_schema)

In [31]:
from typing import List
from langchain_core.documents.base import Document
import tqdm


def _cypher_string(value: str) -> str:
    """Escape backslashes and double quotes for use inside a double-quoted Cypher literal."""
    return value.replace("\\", "\\\\").replace("\"", "\\\"")


def _subject_codes_by_name(names) -> dict:
    """Map every alternative name found in the graph to the code of its subject.

    One graph lookup is issued per entry of `names`; names that match nothing
    are simply absent from the returned dict.
    (An earlier, disabled variant of this lookup additionally pulled in
    subjects connected through a RELATED edge.)
    """
    codes = {}
    for name in names:
        rows = graph.query(f"""
        MATCH (s:Subject)-[:ALTERNATIVE_NAME]->(an:AlternativeName)
        WHERE an.name = "{_cypher_string(name)}"
        RETURN an, s
        """)
        for row in rows:
            codes[row['an']['name']] = row['s']['code']
    return codes


def _parse_subject_list(raw: str):
    """Decode the model reply into the collection of selected subject names.

    Code fences are stripped, as is a leading ``json`` language tag (the
    prompt ends with an opening fence, and the reply may repeat it). A dict
    reply is unwrapped through its ``"subjects"`` key.

    Raises:
        json.JSONDecodeError: if the remaining text is not valid JSON.
        KeyError: if a dict reply has no ``"subjects"`` key.
    """
    text = raw.replace('```', '').strip()
    if text[:4].lower() == "json":
        text = text[4:]
    parsed = json.loads(text)
    if isinstance(parsed, dict):
        parsed = parsed["subjects"]
    return parsed


# Per-document result: gold codes, candidate codes, and codes kept by the LLM.
extracted_subjects = {}

for document, embedding in tqdm.tqdm(zip(documents, document_embeddings), total=len(documents)):
    # Candidate pool: every subject attached to the 20 nearest stored documents.
    neighbours: List[Document] = vector_store.similarity_search_by_vector(embedding, k=20)
    similar_document_subjects = list(set(itertools.chain.from_iterable(
        neighbour.metadata.get("subjects", []) for neighbour in neighbours
    )))
    filtered_subject_names = _subject_codes_by_name(similar_document_subjects)

    # Keep a handle on the full candidate pool. `filtered_subject_names` is
    # rebound (not mutated) below, so this reference stays unfiltered.
    unfiltered_llm_subject_names = filtered_subject_names

    real_subjects = _subject_codes_by_name(document.metadata["subjects"])

    # Upper bound on hits: gold subjects that made it into the candidate pool.
    max_possibility = sum(1 for name in real_subjects if name in filtered_subject_names)

    # NOTE(review): `{filtered_subject_names.keys()}` is rendered into the
    # prompt as `dict_keys([...])`. Left unchanged so results stay comparable.
    llm_query = f"""
You are an AI that has been trained on a large corpus of scientific articles. You have been asked to provide a list of subjects that are related to the following article: `{document.metadata['title']}`. The article is about `{document.page_content}`. The subjects that can be related to this article are: {filtered_subject_names.keys()}. Out of these subjects, list the ones that ARE related. Output the subjects as a JSON list. Output ONLY the JSON list. Try to include about 5 or less subjects.

The subjects that are related to this article are:
```json
"""

    response = llm.invoke(llm_query)
    try:
        filtered_subjects = _parse_subject_list(response.content)
        filtered_subject_names = {k: v for k, v in filtered_subject_names.items() if k in filtered_subjects}
    except Exception as e:
        # Best effort: an unusable reply is logged and the candidate pool is
        # left unfiltered for this document, so the long run is not aborted.
        print(response.content)
        print(e)

    # Gold subjects that survived the LLM filter.
    did = sum(1 for name in real_subjects if name in filtered_subject_names)

    print((len(real_subjects), len(unfiltered_llm_subject_names), len(filtered_subject_names), (max_possibility, did)))

    extracted_subjects[document.id] = {
        "real": list(set(real_subjects.values())),
        "before_filtering": list(set(unfiltered_llm_subject_names.values())),
        "extracted": list(set(filtered_subject_names.values()))
    }

  0%|          | 1/6980 [00:06<12:41:36,  6.55s/it]

(5, 30, 5, (5, 0))


  0%|          | 1/6980 [00:12<23:49:18, 12.29s/it]


KeyboardInterrupt: 

In [88]:
# Show only the first few entries: the dict gains one entry per processed
# document, so displaying it whole floods the notebook output.
dict(itertools.islice(extracted_subjects.items(), 3))

{'https://www.tib.eu/de/suchen/id/TIBKAT%3A1831632292': {'real': ['gnd:4015999-1',
   'gnd:4252654-1',
   'gnd:4139716-2',
   'gnd:4066528-8',
   'gnd:4124477-1'],
  'before_filtering': ['gnd:4129281-9',
   'gnd:4015999-1',
   'gnd:4056243-8',
   'gnd:4172189-5',
   'gnd:4252654-1',
   'gnd:4014894-4',
   'gnd:4139716-2',
   'gnd:4066528-8',
   'gnd:4124477-1'],
  'extracted': ['gnd:4015999-1',
   'gnd:4056243-8',
   'gnd:4139716-2',
   'gnd:4066528-8',
   'gnd:4124477-1']},
 'https://www.tib.eu/de/suchen/id/TIBKAT%3A1831636492': {'real': ['gnd:4078523-3',
   'gnd:4061619-8'],
  'before_filtering': ['gnd:4326464-5',
   'gnd:4511937-5',
   'gnd:4050129-2',
   'gnd:4016397-0',
   'gnd:4002963-3',
   'gnd:4061619-8'],
  'extracted': ['gnd:4326464-5',
   'gnd:4002963-3',
   'gnd:4050129-2',
   'gnd:4016397-0']},
 'https://www.tib.eu/de/suchen/id/TIBKAT%3A1831640708': {'real': ['gnd:4112736-5',
   'gnd:4148259-1'],
  'before_filtering': ['gnd:4066528-8',
   'gnd:4017195-4',
   'gnd:4011152-

In [30]:
# Per-document scores plus their averages over all evaluated documents.
# NOTE(review): the metric names do not follow the usual definitions.
# "accuracy" is hits / |real| (normally called recall), and "recall" divides
# by the size of the pre-filter candidate pool. The formulas are kept as
# written; confirm which definitions are intended.
avg_accuracy = 0
avg_recall = 0
avg_precision = 0

for pair in extracted_subjects.values():
    # Gold codes that are also in the final extracted set.
    hits = len(set(pair['real']).intersection(pair['extracted']))
    # `or 1` keeps an empty denominator from raising; the score is then 0.
    accuracy = hits / (len(pair['real']) or 1)
    recall = hits / (len(pair['before_filtering']) or 1)
    precision = hits / (len(pair['extracted']) or 1)
    print(f"Accuracy: {accuracy}, Recall: {recall}, Precision: {precision}")
    avg_accuracy += accuracy
    avg_recall += recall
    avg_precision += precision

# Same guard for the averages: with no evaluated documents they come out as 0.0
# instead of raising ZeroDivisionError.
document_count = len(extracted_subjects) or 1
avg_accuracy /= document_count
avg_recall /= document_count
avg_precision /= document_count

print(f"Average accuracy: {avg_accuracy}, Average recall: {avg_recall}, Average precision: {avg_precision}")

Accuracy: 1.0, Recall: 0.16666666666666666, Precision: 0.16666666666666666
Accuracy: 1.0, Recall: 0.043478260869565216, Precision: 0.043478260869565216
Accuracy: 1.0, Recall: 0.07692307692307693, Precision: 0.07692307692307693
Accuracy: 1.0, Recall: 0.07317073170731707, Precision: 0.07317073170731707
Accuracy: 1.0, Recall: 0.13043478260869565, Precision: 0.13043478260869565
Accuracy: 1.0, Recall: 0.09090909090909091, Precision: 0.09090909090909091
Accuracy: 0.25, Recall: 0.05263157894736842, Precision: 0.05263157894736842
Accuracy: 1.0, Recall: 0.15384615384615385, Precision: 0.15384615384615385
Accuracy: 1.0, Recall: 0.3333333333333333, Precision: 0.3333333333333333
Accuracy: 1.0, Recall: 0.07692307692307693, Precision: 0.07692307692307693
Accuracy: 1.0, Recall: 0.1875, Precision: 0.1875
Average accuracy: 0.9318181818181818, Average recall: 0.12598334115766774, Average precision: 0.12598334115766774
