In [1]:
import os
import re
from enum import Enum

import pexpect

class Mode(Enum):
    """Selects which StarSpace command-line tool a StarSpaceInterface drives."""

    # Spawns the ``query_predict`` binary.
    PREDICT = 1
    # Spawns the ``embed_doc`` binary.
    EMBEDDING = 2
    # Spawns the ``query_nn`` binary.
    SIMILAR_N = 3

class StarSpaceInterface(object):
    """Thin pexpect wrapper around StarSpace's interactive command-line tools.

    Depending on ``mode`` the instance spawns ``query_predict``, ``embed_doc``
    or ``query_nn`` once, waits for its input prompt, and then answers every
    ``predict`` call by writing one line to the child process and parsing what
    it prints back.

    The child process stays alive until ``close()`` is called; the instance can
    also be used as a context manager (``with StarSpaceInterface(...) as s:``).
    """

    # Prompt printed by ``query_predict`` and ``query_nn`` when ready for input.
    _QUERY_PROMPT = 'Enter some text: '
    # Prompt printed by ``embed_doc`` when ready for input.
    _EMBED_PROMPT = 'Input your sentence / document now:'

    # One prediction row: ``<rank>[<score>]: <label tokens>``. The score may be
    # an integer ("1"), negative, or in exponent notation, not only "0.123".
    _PREDICTION_RE = re.compile(
        r'(\d+)\[([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)\]:\s(.*)')

    def __init__(self, mode, **kwarg):
        """Spawn the StarSpace tool for ``mode`` and wait until it is ready.

        Keyword arguments:
            starspace_path: directory containing the StarSpace binaries
                (required; a trailing slash is optional).
            model_path: model file handed to the binary (required).
            k: number of results requested; defaults to 1 for
                ``Mode.PREDICT`` and 5 for ``Mode.SIMILAR_N``.
            baseDocs: optional extra argument for ``query_predict``
                (``Mode.PREDICT`` only); omitted when empty.
            timeout: seconds pexpect waits for output (default 30).

        Raises:
            ValueError: if ``mode`` is not a supported ``Mode`` member.
            KeyError: if ``starspace_path`` or ``model_path`` is missing.
        """
        self.mode = mode
        self.child = None

        if mode == Mode.PREDICT:
            binary, prompt = "query_predict", self._QUERY_PROMPT
            extra_args = [str(kwarg.get("k", 1))]
            base_docs = kwarg.get("baseDocs", "")
            if base_docs:
                extra_args.append(str(base_docs))
        elif mode == Mode.EMBEDDING:
            binary, prompt = "embed_doc", self._EMBED_PROMPT
            extra_args = []
        elif mode == Mode.SIMILAR_N:
            binary, prompt = "query_nn", self._QUERY_PROMPT
            extra_args = [str(kwarg.get("k", 5))]
        else:
            raise ValueError("Unsupported mode: {!r}".format(mode))

        executable = os.path.join(kwarg["starspace_path"], binary)
        # Pass an argument list instead of one formatted string so that paths
        # containing spaces are not split into several arguments.
        argv = [str(kwarg["model_path"])] + extra_args
        self.child = pexpect.spawn(executable, argv,
                                   timeout=kwarg.get("timeout", 30))
        try:
            self.child.expect(prompt)
        except Exception:
            # Do not leave an orphaned process behind if start-up fails.
            self.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Terminate the child process, if any. Safe to call more than once."""
        child = getattr(self, "child", None)
        if child is not None:
            child.close(force=True)
            self.child = None

    def predict(self, text):
        """Send ``text`` to the child process and return its parsed answer.

        Returns a list of dicts whose keys depend on the mode:
            Mode.PREDICT:   ``label`` (str) and ``proba`` (float).
            Mode.EMBEDDING: ``word`` (str) and ``embedding`` (list of str),
                            always a single-element list.
            Mode.SIMILAR_N: ``word`` (str) and ``proba`` (str).

        Raises:
            ValueError: if the instance's mode is not a supported ``Mode``.
        """
        if self.mode == Mode.PREDICT:
            return self._parse_predictions(self._query(text))
        if self.mode == Mode.EMBEDDING:
            return self._embed(text)
        if self.mode == Mode.SIMILAR_N:
            return self._parse_neighbours(self._query(text))
        raise ValueError("Unsupported mode: {!r}".format(self.mode))

    def _query(self, text):
        """Send one line and return everything printed before the next prompt."""
        self.child.sendline(text)
        self.child.expect(self._QUERY_PROMPT)
        return self.child.before.decode('utf-8')

    def _embed(self, text):
        """Run one ``embed_doc`` round trip and return the raw vector tokens."""
        # Match the echoed input literally, as UTF-8 bytes: used as a regex, a
        # query such as "c++" or "(" would break the pattern, and pexpect in
        # bytes mode coerces str patterns with the ASCII codec, which rejects
        # non-ASCII queries.
        echoed = text.encode('utf-8') + b"\r\n"
        self.child.sendline(text)
        # The input line is consumed twice before the vector (presumably the
        # terminal echo followed by embed_doc repeating it - TODO confirm).
        self.child.expect_exact(echoed)
        self.child.expect_exact(echoed)
        # The vector line ends with a trailing space before the line break.
        self.child.expect_exact(b" \r\n")
        vector = self.child.before.decode("utf-8").split(" ")
        return [dict(word=text.strip(), embedding=vector)]

    @classmethod
    def _parse_predictions(cls, stdout):
        """Extract ``label``/``proba`` dicts from ``query_predict`` output."""
        results = []
        for line in stdout.split('\n'):
            match = cls._PREDICTION_RE.search(line)
            if match:
                results.append(dict(label=match.group(3).strip(),
                                    proba=float(match.group(2))))
        return results

    @staticmethod
    def _parse_neighbours(stdout):
        """Extract ``word``/``proba`` dicts from ``query_nn`` output."""
        results = []
        # The first line is the echoed query and the last one is the empty
        # remainder after the final line break.
        for line in stdout.split('\n')[1:-1]:
            fields = line.split()
            if len(fields) < 2:
                # Blank or malformed line: skip it instead of raising IndexError.
                continue
            results.append(dict(word=fields[0], proba=fields[1]))
        return results


In [2]:
# Start the interface in EMBEDDING mode against the category model.
interface = StarSpaceInterface(
    Mode.EMBEDDING,
    starspace_path="../Starspace/",
    model_path="../data/models/catagoryEmbeddings.bin",
)

In [3]:
# Query the EMBEDDING-mode interface; the result list is shown as the cell output.
interface.predict("lenovo")

[{'word': 'lenovo',
  'embedding': ['0.0205605',
   '-0.0534521',
   '-0.0610849',
   '-0.0766744',
   '0.0281063',
   '-0.00221971',
   '-0.143532',
   '0.0161936',
   '-0.0253927',
   '-0.0735328',
   '0.111146',
   '0.145932',
   '-0.0443884',
   '-0.14476',
   '0.11287',
   '-0.16737',
   '0.0106506',
   '0.209247',
   '0.0553636',
   '-0.112312',
   '0.0206334',
   '-0.025157',
   '-0.0298574',
   '-0.104654',
   '0.0270314',
   '0.179489',
   '-0.0248981',
   '-0.0779187',
   '-0.179245',
   '0.0654831',
   '0.0332258',
   '0.0726604',
   '0.000404798',
   '0.111055',
   '0.154709',
   '-0.158142',
   '0.179709',
   '0.0237634',
   '-0.0716949',
   '0.000763156',
   '0.0785529',
   '0.041405',
   '0.146781',
   '-0.121662',
   '-0.214631',
   '0.0253664',
   '0.196819',
   '-0.0983251',
   '-0.120748',
   '0.0965847',
   '0.174694',
   '0.00515233',
   '0.0216982',
   '0.0723112',
   '0.144232',
   '0.0427683',
   '-0.111836',
   '-0.216062',
   '0.0483602',
   '0.0656993',
   '-

In [4]:
# Start the interface in PREDICT mode against the category model (default k).
interface_pr = StarSpaceInterface(
    Mode.PREDICT,
    starspace_path="../Starspace/",
    model_path="../data/models/catagoryEmbeddings.bin",
)

In [5]:
# Query the PREDICT-mode interface; the label/proba list is shown as the cell output.
interface_pr.predict("lenovo")

[{'label': '__label__Ostatní_příslušenství_pro_mobilní_telefony',
  'proba': 0.610883}]

In [6]:
# Start the interface in SIMILAR_N mode against the category model (default k).
interface_si = StarSpaceInterface(
    Mode.SIMILAR_N,
    starspace_path="../Starspace/",
    model_path="../data/models/catagoryEmbeddings.bin",
)

In [7]:
# Query the SIMILAR_N-mode interface; the word/proba list is shown as the cell output.
interface_si.predict("lenovo")

[{'word': 'lenovo', 'proba': '1'},
 {'word': 'acer', 'proba': '0.761748'},
 {'word': 'Lenovo', 'proba': '0.713689'},
 {'word': '80V60009CK', 'proba': '0.692045'},
 {'word': 'Y6pro', 'proba': '0.685429'}]