In [None]:
"""
Display examples of each grammar

Run the first cell to generate grammars, then the second cell to output the examples.
"""

import random
import torch
from collections import defaultdict
import importlib
import string
import numpy as np
import pandas as pd

from ulfs import tensor_utils

from mll import mem_common

# Grammar variants to sample, in the order they are processed and displayed.
grammars = [
    'Compositional',
    'Permute',
    'Cumrot',
    'RandomProj',
    'ShuffleWordsDet',
    'Holistic',
]

# Grammar construction parameters.
num_atts = 5  # passed to the grammar constructor as num_meaning_types
num_vals = 10  # passed to the grammar constructor as meanings_per_type
tokens_per_meaning = 4
vocab_size = 4

# Seed applied to random, torch and numpy before each grammar is built.
seed = 123

def tensor_to_str(t):
    """Convert tensor `t` to a string via ulfs' tensor_utils.tensor_to_str.

    The lowercase ASCII alphabet is supplied as the vocabulary; the result of
    the underlying helper is returned unchanged.
    """
    alphabet = string.ascii_lowercase
    return tensor_utils.tensor_to_str(t, vocab=alphabet)

# Short display name for each grammar; used as the per-grammar key in
# utt_by_grammar_by_meaning (and therefore as table column/row labels).
grammar_display_names = {
    'Compositional': 'comp',
    'Permute': 'perm',
    'RandomProj': 'proj',
    'ShuffleWordsDet': 'shufdet',
    'Cumrot': 'rot',
    'Holistic': 'hol',
}

# Positions in each grammar's meaning list for which an example is recorded.
example_meaning_idxs = [0, 1, 10, 100]

# meaning tuple -> {grammar display name -> utterance string}
utt_by_grammar_by_meaning = defaultdict(dict)
for grammar in grammars:
    print(grammar)
    # 'Compositional' and 'Holistic' are looked up as their own classes in
    # mem_common; every other name is passed as `corruptions` to
    # CompositionalGrammar.
    if grammar in ('Compositional', 'Holistic'):
        grammar_family, corruption = grammar, None
    else:
        grammar_family, corruption = 'Compositional', grammar

    # Re-seed all three RNGs before every construction so that each grammar
    # is built starting from the same random state.
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    Grammar = getattr(mem_common, f'{grammar_family}Grammar')
    grammar_object = Grammar(
        num_meaning_types=num_atts,
        tokens_per_meaning=tokens_per_meaning,
        meanings_per_type=num_vals,
        vocab_size=vocab_size,
        corruptions=corruption
    )
    _utts = grammar_object.utterances_by_meaning
    _meanings = grammar_object.meanings
    print('_utts.size()', _utts.size(), '_meanings.size()', _meanings.size())

    # Resolved once per grammar instead of rebuilding the mapping for every
    # example index.
    grammar_disp = grammar_display_names[grammar]
    for meaning_idx in example_meaning_idxs:
        _meaning = tuple(_meanings[meaning_idx].tolist())
        _utt = tensor_to_str(_utts[meaning_idx])
        print(meaning_idx, _utt, _meaning)
        utt_by_grammar_by_meaning[_meaning][grammar_disp] = _utt

In [None]:
# Meanings recorded while sampling the grammars, in insertion order.
meanings_to_show = list(utt_by_grammar_by_meaning)
print('meanings_to_show', meanings_to_show)

# One row per meaning: the meaning itself followed by one entry per grammar.
df_rows = [
    {'meaning': meaning, **utt_by_grammar_by_meaning[meaning]}
    for meaning in meanings_to_show
]
df = pd.DataFrame(df_rows)
print(df)

# LaTeX table body (booktabs rules): one row per meaning, one column per grammar.
# Cells are emitted in the same column order as the header (taken from
# df.columns), and a meaning that has no utterance for some grammar gets an
# empty cell, so every row stays aligned with the header.
grammar_columns = [col for col in df.columns if col != 'meaning']

latex_str = ''
latex_str += '\\toprule \n'
latex_str += ' & '.join([f'\\textsc{{{col}}}' for col in df.columns]) + ' \\\\ \n'
latex_str += '\\midrule \n'
for meaning in meanings_to_show:
    utt_by_grammar = utt_by_grammar_by_meaning[meaning]
    latex_str += str(meaning)
    for grammar in grammar_columns:
        latex_str += ' & ' + str(utt_by_grammar.get(grammar, ''))
    latex_str += ' \\\\ \n'
latex_str += '\\bottomrule \n'

print(latex_str)

# Transposed LaTeX table body (booktabs rules): one row per grammar, one
# column per meaning.

# Grammar display names in first-seen order across all meanings. Derived
# explicitly so this cell does not depend on a loop variable left over from
# an earlier loop, and still works when meanings_to_show is empty.
grammars_to_show = list(dict.fromkeys(
    grammar
    for meaning in meanings_to_show
    for grammar in utt_by_grammar_by_meaning[meaning]
))

latex_str = ''
latex_str += '\\toprule \n'
# The backslash is escaped explicitly: '\m' is not a valid escape sequence.
latex_str += ' & \\multicolumn{' + str(len(meanings_to_show)) + '}{l}{Objects} \\\\ \n'
latex_str += ' & ' + ' & '.join([str(meaning) for meaning in meanings_to_show]) + ' \\\\ \n'
latex_str += '\\midrule \n'
for grammar in grammars_to_show:
    latex_str += grammar
    for meaning in meanings_to_show:
        # Look up the utterance for this (meaning, grammar) cell; leave the
        # cell empty if the grammar has no utterance for this meaning.
        latex_str += ' & ' + str(utt_by_grammar_by_meaning[meaning].get(grammar, ''))
    latex_str += ' \\\\ \n'
latex_str += '\\bottomrule \n'

print(latex_str)