In [1]:
####### Imports #######

In [2]:
import importlib

from tangle.models.baseline_constants import MODEL_PARAMS
from tangle.models.utils.language_utils import ALL_LETTERS_POETS
from tangle.lab.lab_transaction_store import LabTransactionStore

ModuleNotFoundError: No module named 'tangle'

In [None]:
####### General Config #######

In [None]:
# Experiment to inspect: run name, config sub-directory index, and the tangle
# transaction whose stored weights are loaded into the model.
experiment_name = 'poets-stacked_lstm-2'
experiment_config = '0'
experiment_transaction_id = '8e5a1029b9ee9d0abb6e57c0f3553069ac824bdc'

# Seed passed as the first argument when constructing the ClientModel.
lab_seed = 0

class ModelConfig:
    """Model/training settings consumed by ``create_client_model``.

    ``dataset`` and ``model`` have no default and must be provided.
    The remaining fields default to the values used when they are not overridden:
    ``lr=-1`` keeps the learning rate from MODEL_PARAMS, one epoch,
    and a batch size of 10.
    """

    def __init__(self, dataset=None, model=None, lr=-1, num_epochs=1, batch_size=10):
        self.dataset = dataset          # e.g. 'poets'; no default
        self.model = model              # e.g. 'stacked_lstm'; no default
        self.lr = lr                    # -1 -> keep the value from MODEL_PARAMS
        self.num_epochs = num_epochs
        self.batch_size = batch_size


model_config = ModelConfig(
    dataset='poets',
    model='stacked_lstm',
    lr=0.8,
    num_epochs=1,
    batch_size=10,
)

In [None]:
####### Logic #######

In [None]:
# Adapted from lab.py (model construction); keep in sync if lab.py changes.

def create_client_model(seed, model_config):
    """Construct the ClientModel described by ``model_config``.

    The model class is imported from ``tangle.models.<dataset>.<model>``.
    It is constructed with the matching ``MODEL_PARAMS`` entry.
    If ``model_config.lr`` is set (anything other than -1 or None), it replaces
    the first of those parameters. Otherwise the MODEL_PARAMS value is kept.

    Args:
        seed: passed as the first positional argument to ``ClientModel``.
        model_config: object with ``dataset``, ``model``, ``lr``,
            ``num_epochs`` and ``batch_size`` attributes.

    Returns:
        The constructed model, with ``num_epochs`` and ``batch_size`` set
        from ``model_config``.
    """
    model_key = '%s.%s' % (model_config.dataset, model_config.model)
    mod = importlib.import_module('.' + model_key, package='tangle.models')
    ClientModel = getattr(mod, 'ClientModel')

    model_params = MODEL_PARAMS[model_key]
    # Both -1 and None mean "no override". Without the None check, an unset lr
    # would be passed to ClientModel as None.
    if model_config.lr is not None and model_config.lr != -1:
        model_params = (model_config.lr,) + tuple(model_params)[1:]

    model = ClientModel(seed, *model_params)
    model.num_epochs = model_config.num_epochs
    model.batch_size = model_config.batch_size
    return model

In [None]:
def test(client_model, transaction_store, transaction_id, data):
    """Evaluate ``client_model`` on ``data`` using one transaction's weights.

    Loads the weights stored for ``transaction_id`` into ``client_model``.
    It then runs ``client_model.test(data)`` and maps the character indices
    reported in ``results['additional_metrics']`` to characters of
    ``ALL_LETTERS_POETS``.

    Returns:
        ``(predicted, expected)``: lists of characters decoded from
        ``additional_metrics[0]`` and ``additional_metrics[1]`` respectively
        (presumably one entry per sample in ``data``).

    Raises:
        ValueError: if the model reports different numbers of predicted and
            expected indices.
    """
    # Load and set params for model (based on transaction id)
    client_model.set_params(transaction_store.load_transaction_weights(transaction_id))

    results = client_model.test(data)
    metrics = results['additional_metrics']
    predicted_idx, expected_idx = metrics[0], metrics[1]

    # The two lists are paired up by the caller. A length mismatch would
    # otherwise surface as an IndexError or silently drop samples.
    if len(predicted_idx) != len(expected_idx):
        raise ValueError('model reported %d predicted but %d expected indices'
                         % (len(predicted_idx), len(expected_idx)))

    predicted = [ALL_LETTERS_POETS[idx] for idx in predicted_idx]
    expected = [ALL_LETTERS_POETS[idx] for idx in expected_idx]
    return predicted, expected

In [None]:
# Transaction store backed by the experiment's tangle_data directory.
tangle_data_dir = f'../experiments/{experiment_name}/config_{experiment_config}/tangle_data'
transaction_store = LabTransactionStore(tangle_data_dir)

# Model of the configured dataset/architecture; weights are set later from a transaction.
client_model = create_client_model(lab_seed, model_config)

In [None]:
# Test samples as (input text, expected next character) pairs, so each input
# stays next to its label. Edit or extend these (or load them from a file).
# All inputs below are exactly 80 characters long.
test_samples = [
    ("Das ist ein wunderbarer Beispieltext, der exakt 80 Zeichen lang ist. Gleich komm", "t"),
    ("sagte er und ging auf sie zu. \"Du hast mir nichts zu sagen!\" Er war außer sich v", "o"),
    ("ch prangt, Wenn sie das goldne Vlies erlangt, Ihr die Kabiren. Wenn sie das gold", "e"),
]

test_data = {
    'x': [text for text, _ in test_samples],
    'y': [label for _, label in test_samples],
}

In [None]:
# Evaluate the configured transaction's weights on the samples above.
predicted, expected = test(
    client_model,
    transaction_store,
    transaction_id=experiment_transaction_id,
    data=test_data,
)

In [None]:
# Show each input next to the character it should be followed by and the model's guess.
result_rows = zip(test_data['x'], expected, predicted)
for sample_text, target_char, model_char in result_rows:
    report_line = "For '%s' expected '%s', got '%s'" % (sample_text, target_char, model_char)
    print(report_line)