In [1]:
from transformers import AutoTokenizer, AutoModel
import torch
import psycopg2
from tqdm import tqdm
import json
from psycopg2 import sql
database = {
    "database": "postgres",
    "user": "postgres",
    "password": "password",
    "host": "192.168.1.16",
    "port": "5432"
}

def connect_to_db():
    return psycopg2.connect(
        dbname=database["database"],
        user=database["user"],
        password=database["password"],
        host=database["host"],
        port=database["port"]
    )

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import AutoTokenizer, AutoModel
import torch
import base64


def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] 
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
model = AutoModel.from_pretrained('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')

In [3]:
conn = connect_to_db()
curs = conn.cursor()

curs.execute(f"""
    SET search_path TO ag_catalog;
    SELECT * FROM cypher('gnd', $$
        MATCH (s:Subject)
        RETURN s
    $$) AS (result agtype);
""")
labels = curs.fetchall()

labels = [json.loads(a[0].replace("::vertex", "")) for a in labels]
conn.close()

print(len(labels))
print(labels[23])

204740
{'id': 844424930327868, 'label': 'Subject', 'properties': {'code': 'gnd:1022027832', 'name': 'Memristor', 'classification_name': 'Elektronik, Nachrichtentechnik'}}


In [6]:
import itertools
import urllib
def convert_to_embeddings(text):
    tokenized_text = tokenizer(text, padding=True, truncation=True, return_tensors='pt')

    with torch.no_grad():
        model_output = model(**tokenized_text)

    return mean_pooling(model_output, tokenized_text['attention_mask']).tolist()

conn = connect_to_db()
cursor = conn.cursor()
batch_size = 10
total_batches = len(labels)//batch_size
commit_interval = 100
for i, batch in enumerate(tqdm(itertools.batched(labels, batch_size), total= total_batches)):
    text_data = []
    for label in batch:
        if 'name' in label['properties'].keys():
            text = urllib.parse.unquote(label['properties']['name']) + " " + urllib.parse.unquote(label['properties']['classification_name'])
        else:
            text = urllib.parse.unquote(label['properties']['classification_name'])

        text_data.append(text)

    try:
        print(text_data)
        embeddings = convert_to_embeddings(text_data)
        print(text_data)
        query_params = [(urllib.parse.unquote(doc['properties']['code']), "[" + ",".join(map(str, embedding)) + "]") for doc, embedding in zip(batch, embeddings)]

        cursor.executemany("INSERT INTO label_embeddings (label_code, embedding) VALUES (%s, %s::vector)", query_params)

    except Exception as e:
        print(e)
        conn.rollback()
        break
    
    if i % commit_interval == 0:
        conn.commit()

conn.commit()
cursor.close()
conn.close()


  0%|          | 0/20474 [00:00<?, ?it/s]

['DPSK Elektronik, Nachrichtentechnik', 'Frequenzdiversity Elektronik, Nachrichtentechnik', 'Antennendiversity Elektronik, Nachrichtentechnik', 'Empfängerdiversity Elektronik, Nachrichtentechnik', 'NAND-Gatter Elektronik, Nachrichtentechnik', 'DVB-SH Elektronik, Nachrichtentechnik', 'MediaFLO Elektronik, Nachrichtentechnik', 'Pentacon AK 8 Elektronik, Nachrichtentechnik', 'Kabelkopfstelle Elektronik, Nachrichtentechnik', 'DSSS Elektronik, Nachrichtentechnik']
['DPSK Elektronik, Nachrichtentechnik', 'Frequenzdiversity Elektronik, Nachrichtentechnik', 'Antennendiversity Elektronik, Nachrichtentechnik', 'Empfängerdiversity Elektronik, Nachrichtentechnik', 'NAND-Gatter Elektronik, Nachrichtentechnik', 'DVB-SH Elektronik, Nachrichtentechnik', 'MediaFLO Elektronik, Nachrichtentechnik', 'Pentacon AK 8 Elektronik, Nachrichtentechnik', 'Kabelkopfstelle Elektronik, Nachrichtentechnik', 'DSSS Elektronik, Nachrichtentechnik']
duplicate key value violates unique constraint "label_embeddings_label




In [5]:
conn.rollback()
conn.close()

InterfaceError: connection already closed