In [3]:
from transformers import AutoTokenizer, AutoModel
import torch
import psycopg2
from tqdm import tqdm
import json
from psycopg2 import sql
database = {
    "database": "postgres",
    "user": "postgres",
    "password": "password",
    "host": "192.168.1.16",
    "port": "5432"
}

def connect_to_db():

    try:
        conn = psycopg2.connect(
            dbname=database["database"],
            user=database["user"],
            password=database["password"],
            host=database["host"],
            port=database["port"]
        )
    except Exception as e:
        print(e)
        raise e

    return conn

In [4]:
conn = connect_to_db()
conn.close()

In [5]:
from transformers import AutoTokenizer, AutoModel
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from transformers.optimization import get_scheduler
from tqdm import tqdm
import numpy as np
import ast
import urllib
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

In [6]:
class DocumentDataset(Dataset):
    @staticmethod
    def _randomly_sample_docs(amount: int):
        conn = connect_to_db()
        
        cursor = conn.cursor()
        cursor.execute("SET search_path TO ag_catalog;")
        cursor.execute(
            f"""
            SELECT * FROM (
                    SELECT * FROM cypher('gnd', $$
                    MATCH (d:Document)
                    RETURN d
                $$) AS (d agtype)
            ) AS subquery
            ORDER BY random()
            LIMIT {amount};
            """
        )

        documents = [json.loads(a[0].replace("::vertex", "")) for a in cursor.fetchall()]
        
        cursor.close()
        conn.close()
        print("Documents fetched!")
        return documents

    @staticmethod
    def _grab_subjects(documents: list):
        conn = connect_to_db()

        cursor = conn.cursor()
        cursor.execute("SET search_path TO ag_catalog;")
        cursor.execute("CREATE INDEX subj_id ON gnd.\"Subject\" USING gin (properties);")
        subjects = []
        for document in tqdm(documents, desc="Loading subjects..."):
            cursor.execute(f"""
                SELECT * FROM cypher('gnd', $$
                    MATCH (d:Document)-[:DOC_SUBJECT]->(s:Subject)
                    WHERE d.id = '{document["properties"]["id"]}'
                    RETURN s
                $$) AS (s agtype);
            """)
            subj = cursor.fetchall()
            doc_subjects = [json.loads(a[0].replace("::vertex", "")) for a in subj]

            subjects.append(doc_subjects)
        cursor.execute("DROP INDEX IF EXISTS gnd.\"subj_id\";")
        cursor.close()
        conn.close()

        return subjects

    # @staticmethod
    # def _grab_embeddings(subject_list: list[list[str]]):
        
    #     conn = connect_to_db()
    #     cursor = conn.cursor()
    #     cursor.execute("CREATE INDEX label_indx ON label_embeddings (label_code)")
    #     embeddings = []
    #     for subset in tqdm(subject_list, desc="Loading embeddings..."):
    #         embeddings_subset = []
    #         for subject in subset:
    #             sub_code = subject["properties"]["code"]
    #             cursor.execute("SELECT embedding FROM label_embeddings WHERE label_code = %s", (sub_code,))
    #             vector_string = cursor.fetchone()[0]
    #             vector = torch.tensor(ast.literal_eval(vector_string), dtype=torch.float32)
    #             embeddings_subset.append(vector)
            
    #         embeddings.append(embeddings_subset)
        
    #     cursor.execute("DROP INDEX IF EXISTS label_indx;")
    #     cursor.close()
    #     conn.close()
    #     return embeddings

    @staticmethod
    def _grab_embeddings(subject_list: list[list[str]]):
        conn = connect_to_db()
        cursor = conn.cursor()

        # Create index for faster lookups
        cursor.execute("CREATE INDEX label_indx ON label_embeddings (label_code)")
        
        embeddings = []
        for subset in tqdm(subject_list, desc="Loading embeddings..."):
        # for subset in subject_list:
            embeddings_subset = []
            for subject in subset:
                
                sub_code = urllib.parse.unquote(subject["properties"]["code"])

                # Fetch the vector from the database
                cursor.execute("SELECT embedding FROM label_embeddings WHERE label_code = %s;", (sub_code,))
                result = cursor.fetchone()
                
                if result is None:
                    # Handle missing label_code
                    print(f"Warning: No embedding found for label_code {sub_code}")
                    continue
                
                embedding_data = result[0]
                
                if embedding_data is None:
                    # Handle NULL embeddings
                    print(f"Warning: NULL embedding for label_code {sub_code}")
                    continue
                
                try:
                    # Convert to PyTorch tensor
                    if isinstance(embedding_data, list):  # If already a list-like object
                        vector = torch.tensor(embedding_data, dtype=torch.float32)
                    elif isinstance(embedding_data, str):  # If it's a string like "[0.1, 0.2, 0.3]"
                        vector = torch.tensor(ast.literal_eval(embedding_data), dtype=torch.float32)
                    else:
                        raise ValueError(f"Unexpected embedding format for label_code {sub_code}: {embedding_data}")
                except Exception as e:
                    print(f"Error processing embedding for label_code {sub_code}: {e}")
                    continue
                
                embeddings_subset.append(vector)
            
            embeddings.append(torch.stack(embeddings_subset).mean(dim=0))
        
        # Drop the index after processing
        cursor.execute("DROP INDEX IF EXISTS label_indx;")
        cursor.close()
        conn.close()
        
        return embeddings

    def __init__(self, size):
        self._documents = self._randomly_sample_docs(size)
        self._subjects = self._grab_subjects(self._documents)
        self._embeddings = self._grab_embeddings(self._subjects)
        self._len = size

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        return {
            "document": self._documents[idx],
            "label_embeddings": self._embeddings[idx]
        }
    

In [12]:
class RandomDocumentDataset(Dataset):
    @staticmethod
    def _randomly_sample_docs(amount: int):
        conn = connect_to_db()
        
        cursor = conn.cursor()
        cursor.execute("SET search_path TO ag_catalog;")
        cursor.execute(
            f"""
            SELECT * FROM (
                    SELECT * FROM cypher('gnd', $$
                    MATCH (d:Document)
                    RETURN d
                $$) AS (d agtype)
            ) AS subquery
            ORDER BY random()
            LIMIT {amount};
            """
        )

        documents = [json.loads(a[0].replace("::vertex", "")) for a in cursor.fetchall()]
        
        cursor.close()
        conn.close()
        print("Documents fetched!")
        return documents

    @staticmethod
    def _grab_subjects(documents: list):
        conn = connect_to_db()

        cursor = conn.cursor()
        cursor.execute("SET search_path TO ag_catalog;")
        cursor.execute("CREATE INDEX subj_id ON gnd.\"Subject\" USING gin (properties);")
        subjects = []
        for document in tqdm(documents, desc="Loading subjects..."):
            cursor.execute(f"""
                SELECT * FROM cypher('gnd', $$
                    MATCH (d:Document)-[:DOC_SUBJECT]->(s:Subject)
                    WHERE d.id = '{document["properties"]["id"]}'
                    RETURN s
                $$) AS (s agtype);
            """)
            subj = cursor.fetchall()
            doc_subjects = [json.loads(a[0].replace("::vertex", "")) for a in subj]

            subjects.append(doc_subjects)
        cursor.execute("DROP INDEX IF EXISTS gnd.\"subj_id\";")
        cursor.close()
        conn.close()

        return subjects

    def _sample_one_doc(self):
        document = self._randomly_sample_docs(1)
        subjects = self._grab_subjects(document)
        embeddings = self._grab_embeddings(subjects)

        return (document, subjects, embeddings)

    @staticmethod
    def _grab_embeddings(subject_list: list[list[str]]):
        conn = connect_to_db()
        cursor = conn.cursor()

        # Create index for faster lookups
        cursor.execute("CREATE INDEX label_indx ON label_embeddings (label_code)")
        
        embeddings = []
        for subset in tqdm(subject_list, desc="Loading embeddings..."):
        # for subset in subject_list:
            embeddings_subset = []
            for subject in subset:
                
                sub_code = urllib.parse.unquote(subject["properties"]["code"])

                # Fetch the vector from the database
                cursor.execute("SELECT embedding FROM label_embeddings WHERE label_code = %s;", (sub_code,))
                result = cursor.fetchone()
                
                if result is None:
                    # Handle missing label_code
                    print(f"Warning: No embedding found for label_code {sub_code}")
                    continue
                
                embedding_data = result[0]
                
                if embedding_data is None:
                    # Handle NULL embeddings
                    print(f"Warning: NULL embedding for label_code {sub_code}")
                    continue
                
                try:
                    # Convert to PyTorch tensor
                    if isinstance(embedding_data, list):  # If already a list-like object
                        vector = torch.tensor(embedding_data, dtype=torch.float32)
                    elif isinstance(embedding_data, str):  # If it's a string like "[0.1, 0.2, 0.3]"
                        vector = torch.tensor(ast.literal_eval(embedding_data), dtype=torch.float32)
                    else:
                        raise ValueError(f"Unexpected embedding format for label_code {sub_code}: {embedding_data}")
                except Exception as e:
                    print(f"Error processing embedding for label_code {sub_code}: {e}")
                    continue
                
                embeddings_subset.append(vector)
            
            embeddings.append(torch.stack(embeddings_subset).mean(dim=0))
        
        # Drop the index after processing
        cursor.execute("DROP INDEX IF EXISTS label_indx;")
        cursor.close()
        conn.close()
        
        return embeddings

    def __init__(self, size):
        self._len = size

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        document, subjects, embeddings = self._sample_one_doc()
        return {
            "document": document[0],
            "label_embeddings": embeddings[0]
        }
    

In [None]:
test_dataset = RandomDocumentDataset(1)
print(test_dataset[0]["label_embeddings"])

Documents fetched!


Loading subjects...: 100%|██████████| 1/1 [00:00<00:00,  8.45it/s]


In [8]:
class FineTunedModel(nn.Module):
    def __init__(self, base_model):
        super(FineTunedModel, self).__init__()
        self.base_model = base_model

    def _mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0] 
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


    def forward(self, input_ids, attention_mask):
        model_output = self.base_model(input_ids = input_ids, attention_mask=attention_mask)
        sentence_embeddings = self._mean_pooling(model_output, attention_mask)

        return sentence_embeddings

embedding_model = FineTunedModel(model)

In [9]:
cosine_loss = nn.CosineEmbeddingLoss()

from torch.optim import AdamW

optimizer = AdamW(embedding_model.parameters(), lr=2e-5)

In [10]:
from torch.nn.utils.rnn import pad_sequence

dataset = DocumentDataset(size=1000)

def custom_collate_fn(batch):
    documents = [item["document"] for item in batch]
    label_embeddings = [item["label_embeddings"] for item in batch]
    
    return {
        "documents": documents,
        "label_embeddings": label_embeddings,
    }

batch_size = 16
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)

Documents fetched!


Loading subjects...: 100%|██████████| 1000/1000 [01:46<00:00,  9.39it/s]
Loading embeddings...: 100%|██████████| 1000/1000 [00:30<00:00, 32.70it/s]


In [11]:
from urllib.parse import unquote

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedding_model.to(device)

epochs = 5
embedding_model.train()
for i in tqdm(range(epochs)):
    epoch_loss = 0
    for batch in data_loader:
        loss=0
        documents = batch["documents"]
        label_embeddings = batch["label_embeddings"]
        input_text = [unquote(doc["properties"]["title"]) + " " + unquote( doc["properties"]["content"]) for doc in documents]
        inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        doc_embeddings = embedding_model(input_ids, attention_mask)
        
        for i , (doc_embedding, label_embedding) in enumerate(zip(doc_embeddings, label_embeddings)):
            avg_embedding = label_embedding.to(device)

            target = torch.tensor(1.0, device=device)
            loss += cosine_loss(doc_embedding, avg_embedding, target)
        
        loss /= len(doc_embeddings)

        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {i + 1}/{epochs}, Avg Loss: {epoch_loss/(len(dataset)/batch_size):.4f}")

torch.save(embedding_model, "fine_tuned_model_complete.pth")

 10%|█         | 1/10 [09:56<1:29:28, 596.51s/it]

Epoch 8/10, Avg Loss: 0.2452


 20%|██        | 2/10 [19:49<1:19:14, 594.29s/it]

Epoch 8/10, Avg Loss: 0.1783


 30%|███       | 3/10 [30:02<1:10:20, 602.99s/it]

Epoch 8/10, Avg Loss: 0.1539


 40%|████      | 4/10 [40:08<1:00:24, 604.16s/it]

Epoch 8/10, Avg Loss: 0.1375


 50%|█████     | 5/10 [50:02<50:02, 600.44s/it]  

Epoch 8/10, Avg Loss: 0.1232


 60%|██████    | 6/10 [1:00:08<40:09, 602.26s/it]

Epoch 8/10, Avg Loss: 0.1109


 70%|███████   | 7/10 [1:09:50<29:47, 595.72s/it]

Epoch 8/10, Avg Loss: 0.1019


 80%|████████  | 8/10 [1:19:34<19:44, 592.10s/it]

Epoch 8/10, Avg Loss: 0.0940


 90%|█████████ | 9/10 [1:29:27<09:52, 592.31s/it]

Epoch 8/10, Avg Loss: 0.0873


100%|██████████| 10/10 [1:39:24<00:00, 596.50s/it]

Epoch 8/10, Avg Loss: 0.0811



