# core

> The core classes and functionality for SentenceGraph

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *

In [2]:
#|export
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from typing import List
from enum import Enum
import pandas as pd
from collections import namedtuple


In [3]:
#| export
class Format(Enum):
    """Output representations that `SentenceGraph.createGraph` can produce."""

    # Nested Python lists of plain floats.
    Python = 0
    # Nested Python lists whose cells are torch tensors.
    Torch = 1
    # Nested Python lists whose cells are numpy arrays.
    Numpy = 2
    # A pandas DataFrame indexed (rows and columns) by node id.
    Pandas = 3

In [4]:
#| export
class TextNode(namedtuple('TextNode', ['nodeId', 'nodeText'])):
    """An immutable (nodeId, nodeText) pair.

    Pairs a caller-chosen identifier with a piece of text so that the rows and
    columns of a SentenceGraph graph can be traced back to their source
    sentences in downstream applications.
    """

In [11]:
#| export
class SentenceGraph:
    """Build pairwise cosine-similarity graphs over sentence embeddings.

    The graph is a square matrix in which entry ``[i][j]`` is the cosine
    similarity between the embeddings of sentence ``i`` and sentence ``j``.
    """

    # Checkpoint loaded when the caller supplies neither a model nor a model name.
    DEFAULT_MODEL_NAME = 'all-MiniLM-L6-v2'

    def __init__(self, model_name: str = None, model: "SentenceTransformer" = None):
        """Select the embedding model.

        Precedence: an explicit ``model`` instance wins; otherwise a non-empty
        ``model_name`` is loaded; otherwise ``DEFAULT_MODEL_NAME`` is loaded.

        Args:
            model_name: Name of a SentenceTransformer checkpoint to load.
                Ignored when ``model`` is given. ``None`` or ``""`` selects
                the default checkpoint.
            model: An already-constructed model exposing ``encode``. When
                given it is used as-is and nothing is loaded.
        """
        if model is not None:
            # Previously a supplied model was discarded whenever model_name was
            # missing; always honour an explicit instance.
            self.model = model
        else:
            # `or` folds both None and "" into the default checkpoint name.
            self.model = SentenceTransformer(model_name or self.DEFAULT_MODEL_NAME)

    def createGraph(self, sentences: List[TextNode], format: Format = Format.Python) -> List[List[float]]:
        """Compute the pairwise cosine-similarity graph of ``sentences``.

        Args:
            sentences: Nodes whose ``nodeText`` is embedded and whose ``nodeId``
                labels the rows/columns of the Pandas output.
            format: Representation of the result:
                ``Format.Python`` - nested lists of floats;
                ``Format.Torch``  - nested lists of 1-element torch tensors;
                ``Format.Numpy``  - nested lists of 1-element numpy arrays;
                ``Format.Pandas`` - a ``pandas.DataFrame`` of floats indexed
                by node id on both axes.

        Returns:
            The similarity graph in the requested representation. Empty input
            yields an empty list (or an empty DataFrame for ``Format.Pandas``).
        """
        # TODO: Wrap this all in a function to export a functional version of this as well.
        nodes = list(sentences)

        if not nodes:
            # Nothing to embed; skip the model so the batched similarity call
            # below never sees an empty embedding array.
            if format == Format.Pandas:
                return pd.DataFrame([], index=[], columns=[])
            return []

        sentence_embeddings = self.model.encode([node.nodeText for node in nodes])

        # A single batched call yields the whole matrix at once
        # (matrix[i][j] = cos_sim(embedding_i, embedding_j)) instead of issuing
        # one cos_sim call per pair of sentences.
        matrix = cos_sim(sentence_embeddings, sentence_embeddings)
        span = range(len(nodes))

        # TODO when we get to python 3.10+ switch to pattern matching
        if format == Format.Python or format == Format.Pandas:
            similarity_graph = matrix.tolist()
        elif format == Format.Torch:
            # j:j + 1 keeps each cell a 1-element tensor, as callers received
            # before. Cells are views into the shared similarity matrix.
            similarity_graph = [[matrix[i, j:j + 1] for j in span] for i in span]
        elif format == Format.Numpy:
            similarity_graph = [[matrix[i, j:j + 1].numpy() for j in span] for i in span]
        else:
            # Not a known Format member: hand back raw 1x1 tensor cells, which
            # is what the per-pair computation produced when no branch matched.
            similarity_graph = [[matrix[i:i + 1, j:j + 1] for j in span] for i in span]

        if format == Format.Pandas:
            index = [node.nodeId for node in nodes]
            return pd.DataFrame(similarity_graph, index=index, columns=index)

        return similarity_graph

In [None]:
#| hide
# Regenerate the exported library module(s) from the notebook cells.
import nbdev

nbdev.nbdev_export()