In [3]:
from sentence_transformers import SentenceTransformer
from collections import defaultdict
import pickle
import numpy as np

class Utterance:
    """Container pairing a piece of text with its embedding.

    Attributes:
        string: the utterance text, stored exactly as passed in.
        embedding: the embedding supplied for that text, stored as-is
            (no copy, no validation).
    """

    def __init__(self, string, embedding):
        self.string, self.embedding = string, embedding

class GraphModel:
    """Retrieval-style responder driven by a pre-computed dialogue graph.

    Given a user utterance, the model:
      1. embeds it with a SentenceTransformer,
      2. maps the embedding to a first-stage "USER" cluster and then to a
         second-stage cluster via the loaded graph,
      3. picks the most likely next cluster from the transition matrix,
      4. returns the training utterance in that next cluster whose embedding
         has the highest dot product with the query embedding.

    Attributes:
        model: the SentenceTransformer used to embed incoming text.
        graph: the unpickled graph object.
        train_dials: the unpickled list of training dialogues.
        transitions: cluster-to-cluster transition scores.
        cluster_to_utts: second-stage cluster id -> list of Utterance.
        cluster_to_embs: second-stage cluster id -> array of the embeddings
            of the utterances in ``cluster_to_utts`` (same order).
    """

    def __init__(
        self,
        graph_path="dialogue_sim/graphs/one_serv_sbert_10.pkl",
        train_dials_path="dialogue_sim/data/train_dials.pkl",
        model_name="all-MiniLM-L6-v2",
    ):
        """Load the encoder, the graph and the training dialogues.

        Args:
            graph_path: path to the pickled graph object.
            train_dials_path: path to the pickled list of training dialogues.
            model_name: SentenceTransformer model name or path.

        All arguments default to the values that were previously hard-coded,
        so ``GraphModel()`` behaves as before.
        """
        self.model = SentenceTransformer(model_name)
        # SECURITY: pickle.load can execute arbitrary code while unpickling.
        # Only point these paths at files you produced yourself or fully trust.
        with open(graph_path, "rb") as file:
            self.graph = pickle.load(file)
        with open(train_dials_path, "rb") as file:
            self.train_dials = pickle.load(file)
        # NOTE(review): the transition matrix is read from the first dialogue
        # only; presumably every dialogue carries the same matrix -- confirm.
        self.transitions = self.train_dials[0].transitions

        # Group every training utterance under its second-stage cluster id.
        self.cluster_to_utts = defaultdict(list)
        for dial in self.train_dials:
            for i, utt_text in enumerate(dial.utterances):
                self.cluster_to_utts[dial.second_stage_clusters[i]].append(
                    Utterance(utt_text, dial.lm_embeddings[i])
                )

        # Stack each cluster's embeddings into one array so that similarity
        # against a query is a single matrix product.
        self.cluster_to_embs = {
            cluster: np.array([utt.embedding for utt in utts])
            for cluster, utts in self.cluster_to_utts.items()
        }

    def get_closest(self, target, utt_embs, k=5):
        """Return the indices of the ``k`` rows of ``utt_embs`` most similar to ``target``.

        Similarity is the plain dot product (it equals cosine similarity only
        if the embeddings are already L2-normalised).

        Args:
            target: query embedding; it is flattened to a column vector, so
                shapes ``(dim,)`` and ``(1, dim)`` are both accepted.
            utt_embs: array of candidate embeddings, one per row.
            k: number of indices to return. Fewer are returned when there are
                fewer than ``k`` candidates; ``k <= 0`` yields an empty result.

        Returns:
            Array of row indices ordered from most to least similar.
        """
        similarity = (utt_embs @ target.reshape(-1, 1)).ravel()
        # Reverse the ascending argsort first, then truncate. The earlier
        # ``[:-k:-1]`` slice stopped one element early and returned only
        # k - 1 indices (and nothing at all for k == 1).
        return np.argsort(similarity)[::-1][:max(k, 0)]

    def get_next_cluster(self, cur_cluster):
        """Return the index of the highest-scoring successor of ``cur_cluster``."""
        return np.argmax(self.transitions[cur_cluster])

    def __call__(self, text):
        """Return the stored utterance that best answers ``text``.

        Args:
            text: a single utterance, either as a one-element list of strings
                or as a bare string (which is wrapped into a list).

        Returns:
            The text of the utterance in the predicted next cluster whose
            embedding has the highest dot product with the query embedding.
        """
        # A bare string would make encode() return a 1-D vector, and
        # ``embedding[0]`` below would then be a scalar instead of a vector.
        if isinstance(text, str):
            text = [text]
        embedding = self.model.encode(text)
        one_stage_cluster = self.graph.one_stage_clustering._subclusters["USER"].predict_cluster(embedding[0]).id
        second_stage_cluster = self.graph.cluster_kmeans_labels[0][one_stage_cluster]
        next_cluster = self.get_next_cluster(second_stage_cluster)
        closest_idxs = self.get_closest(embedding, self.cluster_to_embs[next_cluster])
        # Indices are ordered best-first, so the first one is the reply.
        return self.cluster_to_utts[next_cluster][closest_idxs[0]].string

In [4]:
# Build the responder. Construction instantiates the sentence encoder and
# unpickles the graph and training-dialogue files under dialogue_sim/
# (paths are relative to the current working directory).
# NOTE(review): the recorded output reports a missing 'graph_model' module --
# presumably the pickles reference classes from it, so it must be importable
# before this cell runs; confirm.
model = GraphModel()

ModuleNotFoundError: No module named 'graph_model'

In [4]:
# Smoke test: the query is passed as a one-element list and the call returns
# a single reply string taken from the training utterances.
model(["Hi, I would like to order a taxi"])

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


'The contact number for the taxi is 07648586609.'