In [3]:
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline, AutoModel

from sklearn.metrics.pairwise import cosine_similarity

from datasets import load_dataset
import math
import json

import numpy as np
from deep_translator import GoogleTranslator

import nltk
from nltk.corpus import stopwords

In [4]:
def read_txt(file_path):
    prompts = {}
    current_key = None

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            
            if not line:  # Skip empty lines
                continue

            if line.startswith('topic_'):
                current_key = line[6:]
                prompts[current_key] = []
            elif current_key:
                prompts[current_key].append(line)

    return prompts

In [5]:
questions_path = '../embedding/testQuestions.txt'
contexts_path = '../embedding/contexts.txt'
embContexts_path = '../embedding/embContexts.txt'

#load test_prompts, contexts, and enlarged_contexts_for_classification
prompts = read_txt(questions_path)
contexts = read_txt(contexts_path)
contextsEmb = read_txt(embContexts_path)

In [6]:

# model_name = "ufal/robeczech-base" 
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForMaskedLM.from_pretrained(model_name)
# tokenizer = AutoTokenizer.from_pretrained('DeepPavlov/bert-base-bg-cs-pl-ru-cased')
# model = AutoModel.from_pretrained('DeepPavlov/bert-base-bg-cs-pl-ru-cased')

In [7]:
#pick a right model

tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
feature_extraction_pipeline = pipeline('feature-extraction', model=model, tokenizer=tokenizer) 

In [38]:
class EmbeddingProcessor:
    def __init__(self, pipeline=feature_extraction_pipeline, chunk_size=200):
        self.pipeline = pipeline
        self.chunk_size = chunk_size
        
        nltk.download('stopwords')
        self.stop_words_en = set(stopwords.words('english'))
    
    def _preprocess_text(self, sentence):
        filtered_tokens = [word for word in sentence.split() if word.lower() not in self.stop_words_en]
        return ' '.join(filtered_tokens)
    
    def get_embedding(self, prompt, preprocess=False):
        if preprocess:
            prompt = self._preprocess_text(prompt)
        chunks = [prompt[i:i + self.chunk_size] for i in range(0, len(prompt), self.chunk_size)]
        chunk_embeddings = []

        for chunk in chunks:
            chunk_embedding = self.pipeline(chunk)
            chunk_embedding = np.mean(chunk_embedding[0], axis=0)
            chunk_embeddings.append(chunk_embedding)

        embedding = np.mean(chunk_embeddings, axis=0).reshape(1, -1)
        
        return embedding



In [51]:
class PromptProcessor(EmbeddingProcessor):
    def __init__(self, pipeline=feature_extraction_pipeline, source_lang='cs', target_lang='en', chunk_size=200):
        super().__init__(pipeline=pipeline, chunk_size=chunk_size)
        self.translator = GoogleTranslator()
        self.source_lang = source_lang
        self.target_lang = target_lang
    
    def _translate(self, prompt):
        return self.translator.translate(prompt, src=self.source_lang, dest=self.target_lang)

    def _get_prompt_embedding_class(self, prompt_embedding, embeddings):
        prompt_class = None
        max_sim = -1

        for emb_name, emb_t in embeddings.items():
            sim = cosine_similarity(prompt_embedding, emb_t)
            if sim > max_sim:
                prompt_class = emb_name
                max_sim = sim

        return prompt_class
    
    #translate, preprocess, get embedding, and return class of a prompt
    def process_prompt(self, prompt, embeddings, preprocess=True, translate=True):
        if translate:
            prompt = self._translate(prompt)
        prompt_embedding = self.get_embedding(prompt, preprocess)
        prompt_class = self._get_prompt_embedding_class(prompt_embedding, embeddings)
        
        return prompt_class

In [46]:
embeddingProcessor = EmbeddingProcessor()

[nltk_data] Downloading package stopwords to /home/martin/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [47]:
embeddings = {}
for key, context in contextsEmb.items():
    context_embedding = embeddingProcessor.get_embedding(context[0], preprocess=True)
    embeddings[key] = context_embedding
    

In [48]:
prompt_processor = PromptProcessor()

[nltk_data] Downloading package stopwords to /home/martin/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [49]:
question_embeddings = {}
for key, topic in prompts.items():
    question_embeddings[key] = {}
    for prompt in topic:
        prompt_class = prompt_processor.process_prompt(prompt, embeddings, preprocess=True)
        question_embeddings[key][prompt] = prompt_class


In [50]:
#results on chatGPT generated questoins, some questions have overlap
for topic, questions in question_embeddings.items():
    print(f"Topic: {topic}")
    for question, prompt_class in questions.items():
        print(f"Question: {question}")
        print(f"Class: {prompt_class}")
        print("\n")

Topic: ukraine
Question: Kdy začala konflikt na Ukrajině a jaký byl jeho původ?,
Class: ukraine


Question: Jaký byl vliv anexe Krymu na eskalaci konfliktu?,
Class: ukraine


Question: Které země podporují Ukrajinu a které podporují separatisty?,
Class: ukraine


Question: Jaká je aktuální humanitární situace v postižených oblastech?,
Class: israel_palestine


Question: Jaké jsou diplomatické snahy o řešení konfliktu a mírový proces?,
Class: israel_palestine


Question: Jaký je vliv války na civilní obyvatelstvo v dané oblasti?,
Class: ukraine


Question: Které mezinárodní organizace se aktivně angažují ve snaze ukončit konflikt?,
Class: israel_palestine


Question: Jaké jsou následky války na ekonomiku Ukrajiny a okolních zemí?,
Class: ukraine


Question: Existuje nějaký plán obnovy a rekonstrukce po skončení konfliktu?,
Class: israel_palestine


Question: Jaký je postoj veřejnosti v různých částech světa k ukrajinskému konfliktu?,
Class: ukraine


Topic: palestine
Question: Jak vznik

In [98]:
import pickle

In [99]:
with open('contexts_embeddings.pkl', 'wb') as f:
    pickle.dump(embeddings, f)

In [100]:
with open('contexts.pkl', 'wb') as f:
    pickle.dump(contexts, f)