In [6]:
import pickle
from typing import List

from lazyme import find_files
from IPython.display import clear_output

In [20]:
import torch 

from getalp.wsd.predicter import Predicter
from getalp.wsd.model import Model, ModelConfig, DataConfig

In [3]:
def read_sample_x_from_string(string: str, feature_count: int, clear_text: List[bool]):
    """Parse a whitespace-tokenised sample into one column per input feature.

    Every token of ``string`` is split on ``/`` into its features; feature
    ``i`` of each token is appended to column ``i`` of the result.

    Args:
        string: The sample, tokens separated by whitespace and features
            inside a token separated by ``/``.
        feature_count: Number of leading features to read from each token.
        clear_text: One flag per feature. When true the feature is kept as a
            string (with the ``<slash>`` escape turned back into ``/``);
            otherwise it is parsed as an integer.

    Returns:
        A list of ``feature_count`` columns. Clear-text columns are lists of
        strings; the other columns are 1-D ``torch.long`` tensors on the CPU.

    Raises:
        IndexError: If a token has fewer than ``feature_count`` features or
            ``clear_text`` is shorter than ``feature_count``.
        ValueError: If a non-clear-text feature is not an integer literal.
    """
    sample_x: List = [[] for _ in range(feature_count)]
    for word in string.split():
        word_features = word.split('/')
        for i in range(feature_count):
            if clear_text[i]:
                # A literal "/" inside a feature is escaped as "<slash>".
                sample_x[i].append(word_features[i].replace("<slash>", "/"))
            else:
                sample_x[i].append(int(word_features[i]))

    # Index columns become tensors. The notebook only imports ``torch``, so
    # use its public names rather than the undefined torch_tensor /
    # torch_long / cpu_device aliases.
    cpu = torch.device("cpu")
    for i in range(feature_count):
        if not clear_text[i]:
            sample_x[i] = torch.tensor(sample_x[i], dtype=torch.long, device=cpu)
    return sample_x

In [50]:
class Predictor:
    """Convenience wrapper that runs a WSD model ensemble on raw text.

    It loads the data/model configuration, the ensemble weights and the
    output vocabulary once, and then exposes ``disambiguate`` to map a
    whitespace-tokenised string to one output-vocabulary entry per token.
    """

    def __init__(self, config_filename, ensemble_filenames, outvocab_filename, clear_text=True):
        """Load configuration, ensemble weights and the output vocabulary.

        Args:
            config_filename: Path of the config file; it is loaded into both
                the ``DataConfig`` and the ``ModelConfig``.
            ensemble_filenames: Iterable of weight files, one per ensemble
                member.
            outvocab_filename: Path of the output vocabulary, one entry per
                line.
            clear_text: When true, every input feature is flagged as clear
                text, overriding whatever the config file says.
        """
        # The same config file feeds both the data and the model configuration.
        self.data_config = DataConfig()
        self.data_config.load_from_file(config_filename)

        self.model_config = ModelConfig(self.data_config)
        self.model_config.load_from_file(config_filename)

        if clear_text:
            # Treat every input feature as raw text instead of vocabulary indices.
            model_data_config = self.model_config.data_config
            model_data_config.input_clear_text = [True] * model_data_config.input_features

        self.ensemble = self.load_ensemble(ensemble_filenames)

        # The Predicter instance carries no configuration of its own here; the
        # ensemble is handed to it on every call.
        self.predictor = Predicter()

        self.output_vocab = self.load_sense_inventory(outvocab_filename)

    def load_sense_inventory(self, outvocab_filename):
        """Read the output vocabulary, one stripped entry per line.

        Blank lines are kept (as empty strings) so that list positions keep
        matching the line numbers of the file.

        Args:
            outvocab_filename: Path of the vocabulary file.

        Returns:
            A list of strings, in file order.
        """
        with open(outvocab_filename, encoding="utf-8") as fin:
            return [line.strip() for line in fin]

    def load_ensemble(self, ensemble_filenames):
        """Instantiate one model per weight file and load its weights.

        Args:
            ensemble_filenames: Iterable of weight-file paths.

        Returns:
            A list of ``Model`` objects, in the order of the given paths.
        """
        ensemble = []
        for fn in ensemble_filenames:
            model = Model(self.model_config)
            model.load_model_weights(fn)
            ensemble.append(model)
        return ensemble

    def process_input(self, text):
        """Convert raw text into the sample structure the predictor consumes.

        Args:
            text: Whitespace-tokenised input string.

        Returns:
            The parsed sample, after ``preprocess_sample_x`` has been applied
            to it.
        """
        # NOTE(review): __init__ sets input_clear_text through
        # self.model_config.data_config, while this reads self.data_config --
        # presumably the same object; confirm against ModelConfig.
        sample = read_sample_x_from_string(
            text,
            feature_count=self.data_config.input_features,
            clear_text=self.data_config.input_clear_text,
        )
        # The return value of the preprocessing call is ignored, as in the
        # original code; the sample is presumably updated in place -- TODO confirm.
        self.predictor.preprocess_sample_x(self.ensemble, sample)
        return sample

    def disambiguate(self, text):
        """Predict one output-vocabulary entry per token of ``text``.

        Args:
            text: Whitespace-tokenised input string.

        Returns:
            A list of output-vocabulary strings, one per row of the prediction
            for this sample. Empty or whitespace-only text yields an empty
            list without calling the model.
        """
        if not text.split():
            # No tokens means nothing to predict; avoid sending an empty batch.
            return []

        batch = [self.process_input(text)]
        # Single-sample batch, so take element 0 of the result. The original
        # author's note describes it as an [l x v] tensor.
        tensor_result = self.predictor.predict_ensemble_wsd_on_batch(self.ensemble, batch)[0]
        # Highest-scoring vocabulary index for each row.
        sense_indices = [int(torch.argmax(sense_tensor, dim=0)) for sense_tensor in tensor_result]
        return [self.output_vocab[i] for i in sense_indices]

In [70]:
# Collect every "cpu*" checkpoint under models/ and put them in sorted order.
model_filenames = list(find_files("models/", "cpu*"))
model_filenames.sort()

# Build the ensemble predictor from the config file and the output vocabulary.
p = Predictor(
    config_filename='config.json',
    ensemble_filenames=model_filenames,
    outvocab_filename='output_vocabulary0',
)

# Wipe the logging output this cell produced.
clear_output()

In [71]:
from nltk.corpus import wordnet as wn

output_senses = p.disambiguate('cat like fish')

# Each predicted entry is a one-character part of speech followed by a
# numeric WordNet offset; resolve it to the corresponding synset.
for ss in output_senses:
    pos, offset = ss[0], int(ss[1:])
    # Use the public lookup; the underscore-prefixed variant is deprecated.
    print(wn.synset_from_pos_and_offset(pos, offset))

Synset('person.n.01')
Synset('have.v.01')
Synset('category.n.02')


  Function _synset_from_pos_and_offset() has been deprecated.  Use
  public method synset_from_pos_and_offset() instead
  
