In [None]:
from dataset import TextDataset
import numpy as np
import sys, os, json
import gzip
from colored import fg, attr, bg

from env import JerichoEnv
from tqdm import tqdm
from jericho import *

import random



In [None]:
import re
from jericho.util import verb_usage_count
from jericho.template_action_generator import TemplateActionGenerator

class TemplateActionParser(TemplateActionGenerator):
    """Template generator that can also parse a command back into a template.

    On top of ``TemplateActionGenerator`` this class maintains three lookup
    tables, all filled in by ``_preprocess_templates``:

    * ``verb_to_templates``: leading word -> list of templates registered
      under that word,
    * ``template2template``: template spelled with an alias -> the same
      template spelled with the preferred alias,
    * ``templates_alias_dict``: resolved template -> ``[leading word]``.

    ``parse_action`` then uses ``verb_to_templates`` to turn a command such
    as ``"take lamp"`` into ``['take OBJ', 'lamp']`` (provided a
    ``'take OBJ'`` template is registered).
    """

    # A slash-separated alias group inside a raw template, e.g. "take/get".
    _ALIAS_GROUP_RE = re.compile(r'\S+(/\S+)+')
    # Slot patterns for the two parsing passes: one word, then several words.
    _SINGLE_WORD_SLOT = r'(\S+)'
    _MULTI_WORD_SLOT = r'(\S+(?: \S+)*)'

    def __init__(self, rom_bindings):
        """Build the template tables for the game described by *rom_bindings*.

        Args:
            rom_bindings: forwarded unchanged to ``TemplateActionGenerator``.
        """
        # The lookup tables must exist before the base constructor runs;
        # presumably that constructor calls ``_preprocess_templates`` --
        # TODO confirm against the installed Jericho version.
        self.templates_alias_dict = {}
        self.verb_to_templates = {}
        self.template2template = {}
        super(TemplateActionParser, self).__init__(rom_bindings)

        self.id2template = None
        self.template2id = None

        # Extra templates added on top of what the base class produced.
        self.additional_templates = ['land']
        self.templates = sorted(set(self.templates + self.additional_templates))
        self._compute_template()

        basic_actions = ('north/south/west/east/northwest/southwest/northeast/'
                         'southeast/up/down/enter/exit/take all').split('/')
        self.BASIC_ACTIONS = {k: 1 for k in basic_actions}

        # Identity mapping for the actions that are handled outside of
        # ``verb_to_templates``.
        self.add_template2template = {}
        extra_actions = (list(self.BASIC_ACTIONS)
                         + self.additional_templates
                         + ['examine OBJ'])
        for action in extra_actions:
            self.add_template2template[action] = action

    def _preprocess_templates(self, templates, max_word_length):
        """Resolve alias groups in *templates* and index them by leading word.

        Every slash-separated group (``a/b/c``) is replaced by the alias that
        ``verb_usage_count`` ranks highest. Each alias spelling met on the
        way is recorded in ``template2template`` and ``verb_to_templates``.

        Args:
            templates: iterable of raw template strings; empty ones are
                skipped.
            max_word_length: forwarded to ``verb_usage_count``.

        Returns:
            list of resolved templates, excluding those whose first word is
            in ``defines.ILLEGAL_ACTIONS`` and single-word templates whose
            word is in ``defines.NO_EFFECT_ACTIONS``.
        """
        def usage(verb):
            return verb_usage_count(verb, max_word_length)

        out = []
        for template in templates:
            if not template:
                continue

            # Resolve alias groups left to right, one group per iteration.
            while True:
                match = self._ALIAS_GROUP_RE.search(template)
                if not match:
                    break
                prefix = template[:match.start()]
                suffix = template[match.end():]
                aliases = match.group().split('/')
                preferred_template = prefix + max(aliases, key=usage) + suffix

                for alias in aliases:
                    alias_template = prefix + alias + suffix
                    self.template2template[alias_template] = preferred_template
                    self.verb_to_templates.setdefault(alias, []).append(alias_template)

                template = preferred_template

            ts = template.split()
            if not ts:
                # Whitespace-only template: nothing to register.
                continue
            if ts[0] in defines.ILLEGAL_ACTIONS:
                continue
            if ts[0] in defines.NO_EFFECT_ACTIONS and len(ts) == 1:
                continue

            # NOTE(review): the original guarded this registration with a
            # ``has_alias`` flag that was always False by the time it was
            # tested, so it ran for every template.  That behaviour is kept
            # on purpose: for templates with several alias groups this is the
            # only place where the fully resolved template is registered
            # under its leading word.
            head = ts[0]
            self.verb_to_templates.setdefault(head, []).append(template)
            self.template2template[template] = template
            self.templates_alias_dict[template] = [head]
            out.append(template)
        return out

    def _compute_template(self):
        """Rebuild ``id2template`` / ``template2id`` from ``self.templates``."""
        self.id2template = dict(enumerate(self.templates))
        self.template2id = {t: i for i, t in enumerate(self.templates)}

    @staticmethod
    def _match_template(template, action, slot_pattern):
        """Return ``[template, *slot_values]`` if *action* fits *template*.

        Each ``OBJ`` token of *template* becomes *slot_pattern*; every other
        token is matched literally (regex metacharacters are escaped).  The
        whole of *action* must match, otherwise ``None`` is returned.
        """
        pattern = ' '.join(
            slot_pattern if token == 'OBJ' else re.escape(token)
            for token in template.split()
        )
        match = re.fullmatch(pattern, action)
        if match is None:
            return None
        return [template] + list(match.groups())

    def parse_action(self, action):
        """Split a command into its template and slot fillers.

        Args:
            action: command string, e.g. ``"put lamp in case"``.

        Returns:
            * ``[verb]`` for a single-word command found in ``BASIC_ACTIONS``
              or ``additional_templates``;
            * ``['examine OBJ', rest]`` when the verb is ``examine`` and no
              template is registered for it;
            * ``[template, filler, ...]`` for the first registered template
              that matches the whole command;
            * ``None`` when the command is empty, the verb is unknown (a
              message is printed), or no template matches.
        """
        tokens = action.split()
        if not tokens:
            # Empty / whitespace-only command: nothing to parse.
            return None
        verb = tokens[0]

        # Only single-word commands take this shortcut, so the two-word
        # 'take all' entry of BASIC_ACTIONS is never returned from here.
        if len(tokens) == 1 and (verb in self.BASIC_ACTIONS
                                 or verb in self.additional_templates):
            return [verb]

        if verb not in self.verb_to_templates:
            if verb == 'examine':
                return ['examine OBJ', ' '.join(tokens[1:])]
            print('cannot recognize verb:', verb)
            return None

        templates = self.verb_to_templates[verb]
        # First try every template with single-word slots; only if none fits
        # retry all of them allowing multi-word slot fillers.
        for slot_pattern in (self._SINGLE_WORD_SLOT, self._MULTI_WORD_SLOT):
            for template in templates:
                parsed = self._match_template(template, action, slot_pattern)
                if parsed is not None:
                    return parsed
        return None
    
# act_par = TemplateActionParser(bindings)
# print(act_par.templates_alias_dict)
# print(act_par.verb_to_templates)

In [None]:
def get_walkthrough(rom_path):
    """Replay the walkthrough stored in the ROM bindings and log each step.

    Args:
        rom_path: path of the game file, passed to ``load_bindings`` and
            ``FrotzEnv``.

    Returns:
        list with the reward returned by ``env.step`` for every walkthrough
        command; empty when the bindings have no ``'walkthrough'`` entry.
    """
    bindings = load_bindings(rom_path)
    rewards = []
    if 'walkthrough' not in bindings:
        return rewards

    commands = bindings['walkthrough'].split('/')
    env = FrotzEnv(rom_path, seed=bindings['seed'])

    running_total = 0.0
    for step_no, command in enumerate(commands):
        print('step:', step_no)
        print('act:', command)
        obs_text, reward, _done, _info = env.step(command)
        print('obs:', obs_text)
        rewards.append(reward)
        running_total += reward
        print('curR:', running_total)

    return rewards


# Replay the zork1 walkthrough and keep the rewards of its first 100 steps.
game_rom_path = "../roms/jericho-game-suite/zork1.z5"
step_scores = get_walkthrough(game_rom_path)

scores = np.array(step_scores[:100])


In [None]:
'''
testing
zork1 40.01 44.62 34 35 33.6 35 32 41.6 31
library 36.76 46.45 14.3 19 10.0 18 19 19 18
detective 60.28 63.21 207.9 214 246.1 274 320 330 304
balances 55.26 56.49 10 10 9.8 10 10 10 10
pentari 63.89 68.37 50.7 56 48.2 56 56 58 40
ztuu 28.71 29.76 6 9 5 5 5 11.8 5
ludicorp 52.32 59.95 17.8 19 17.6 19 19 22.8 20.6
deephome 8.03 9.27 1 1 1 1 8 6 1
temple 
'''
# m, h, e, e, e, m, m, m, m, m
# Games excluded from the training list built below.
eval_games = [
    'zork3', 'anchor', 'detective', 'ztuu', 'temple',
    'yomomma', 'jewel', 'gold', 'karn', 'zenon',
]
eval_games_left = ['anchor', 'yomomma']
train_games_left = ['ludicorp', 'spirit', 'tryst205', 'spellbrkr']

all_games = [
    '905', 'acorncourt', 'advent', 'adventureland', 'afflicted', 'anchor',
    'awaken', 'balances', 'deephome', 'detective', 'dragon', 'enchanter',
    'gold', 'inhumane', 'jewel', 'karn', 'library', 'ludicorp', 'moonlit',
    'omniquest', 'pentari', 'reverb', 'snacktime', 'sorcerer', 'spellbrkr',
    'spirit', 'temple', 'tryst205', 'yomomma', 'zenon', 'zork1', 'zork3',
    'ztuu',
]
games_with_ns_actions = [
    'library', 'pentari', 'ludicorp', 'deephome', 'advent', 'balances',
    'sorcerer', 'tryst205', 'spellbrkr', 'enchanter', 'spirit',
]

# games = ['zork1', 'zork3', 'enchanter', 'spellbrkr', 'sorcerer']
# games = ['zork2', 'wishbringer']
zork_games = [
    'zork1', 'zork3', 'enchanter', 'zork2', 'wishbringer', 'sorcerer',
    'spellbrkr',
]

hard_games = ['sorcerer', 'tryst205', 'spellbrkr', 'anchor', 'enchanter', 'spirit']
middle_games = [
    'ludicorp', 'deephome', 'yomomma', 'advent', 'jewel', 'zork1', 'gold',
    'balances', 'karn', 'zenon', 'zork3',
]

# Training list: every game that is neither in zork_games nor in eval_games,
# kept in all_games order and echoed as it is selected.
games = []
for game in all_games:
    if game in zork_games or game in eval_games:
        continue
    print(game)
    games.append(game)
        
# games = []
# for game in all_games:
#     if game not in eval_games:
#         print(game)
#         games.append(game)
# print(games)
# games = games_with_ns_actions 



In [None]:
#                 state_hash = _get_world_state_hash(self.env)
#                 save = self.env.get_state()
#                 look, _, _, _ = self.env.step('look')
#                 self.env.set_state(save)
#                 inv, _, _, _ = self.env.step('inventory')
#                 self.env.set_state(save)

from jericho.util import clean

def get_full_obs(env, observation):
    """Combine the 'look' text, the 'inventory' text and *observation*.

    The environment state is captured first and restored after each of the
    two probe commands, so the probes are undone before returning.

    Args:
        env: object offering ``get_state()``, ``set_state(state)`` and
            ``step(command)`` returning a 4-tuple whose first item is text.
        observation: text of the step that was just taken.

    Returns:
        ``clean(look) + '|' + clean(inventory) + '|' + clean(observation)``.
    """
    snapshot = env.get_state()

    def probe(command):
        # Issue the command, keep only its text, then roll the game back.
        text, _, _, _ = env.step(command)
        env.set_state(snapshot)
        return text

    room_text = probe('look')
    inventory_text = probe('inventory')

    parts = [clean(room_text), clean(inventory_text), clean(observation)]
    return '|'.join(parts)

def generate_sas_data(game, rom_path):
    """Add next-state observations to a walkthrough trajectory file (verbose).

    Replays the walkthrough from the ROM bindings.  After walkthrough step
    ``idx`` the record on line ``idx`` of
    ``../data/ssa_data/jecc_sup/<game>.ssa.wt_traj.txt`` is loaded and its
    first three ``|``-separated observation fields are compared with
    ``get_full_obs`` for the live game.  On the following step, the first
    action of every group in the record's ``'valid_actions'`` is executed from
    the saved pre-step state, its full observation is stored under
    ``'observations'``, and the record is written as one JSON line to
    ``<game>.sas.wt_traj.txt.new``.

    Changes from the previous revision:
      * line ``idx`` is read instead of ``idx - 1``, which wrapped around to
        the last line on the first step (the commented-out ``readline()`` it
        replaced, and ``generate_sas_data_new``, both read line ``idx``);
      * both files are closed on every exit path, including the early
        return on an observation mismatch;
      * a record is written at most once; previously the last loaded record
        could be written again against a later game state;
      * a flat ``'valid_actions'`` list of dicts is wrapped into one group,
        as ``generate_sas_data_new`` does, and empty groups are skipped.

    Args:
        game: game name used to build the two data-file paths.
        rom_path: path of the game file.

    Returns:
        list of per-step rewards; an empty list when the bindings contain no
        walkthrough; ``None`` when an observation mismatch aborts the run.
    """
    bindings = load_bindings(rom_path)
    scores = []
    if 'walkthrough' not in bindings:
        return scores

    walkthrough = bindings['walkthrough'].split('/')
    seed = bindings['seed']
    env = FrotzEnv(rom_path, seed=seed)

    in_path = '../data/ssa_data/jecc_sup/{}.ssa.wt_traj.txt'.format(game)
    out_path = '../data/ssa_data/jecc_sup/{}.sas.wt_traj.txt.new'.format(game)

    cum_r = 0.0
    with open(in_path) as filein, open(out_path, 'w') as fileout:
        lines = filein.readlines()
        print(len(lines))

        # Record loaded after the previous step and not yet written out.
        wt_ssa_data = None

        for idx, act in enumerate(walkthrough):
            print('step:', idx)
            print('act:', act)

            # State before the walkthrough action: the valid actions of the
            # pending record are tried from here.
            state_save = env.get_state()

            observation, reward, _done, _info = env.step(act)
            print('obs:', observation)
            scores.append(reward)
            cum_r += reward
            print('curR:', cum_r)
            new_ob = get_full_obs(env, observation)

            # State after the walkthrough action, restored once the
            # alternative actions have been explored.
            state_prim_save = env.get_state()

            if wt_ssa_data is not None:
                valid_actions = wt_ssa_data['valid_actions']
                if valid_actions and isinstance(valid_actions[0], dict):
                    # Flat list of action dicts: treat it as a single group.
                    valid_actions = [valid_actions]
                    wt_ssa_data['valid_actions'] = valid_actions

                for valid_act_group in valid_actions:
                    if not valid_act_group:
                        continue
                    # Only the first action of each group is executed.
                    valid_act_tuple = valid_act_group[0]
                    env.set_state(state_save)
                    alt_obs, _, _, _ = env.step(valid_act_tuple['a'])
                    valid_act_tuple['observations'] = get_full_obs(env, alt_obs)

                fileout.write(json.dumps(wt_ssa_data) + '\n')
                env.set_state(state_prim_save)
                wt_ssa_data = None

            if idx < len(walkthrough) - 1 and idx < len(lines) - 1:
                wt_ssa_data = json.loads(lines[idx])
                ssa_obs = '|'.join(wt_ssa_data['observations'].split('|')[0:3])

                if new_ob != ssa_obs:
                    print('ERROR: different obs')
                    print(new_ob)
                    print(ssa_obs)
                    return None

            # Stop replaying once the walkthrough has outrun the input file.
            if idx == len(lines):
                break

    return scores


def generate_sas_data_new(game, rom_path):
    """Add next-state observations to a walkthrough trajectory file.

    Replays the walkthrough from the ROM bindings.  After walkthrough step
    ``idx`` the record on line ``idx`` of
    ``../data/ssa_data/jecc_sup/<game>.ssa.wt_traj.txt`` is loaded and its
    first three ``|``-separated observation fields are compared with
    ``get_full_obs`` for the live game.  On the following step, the first
    action of every group in the record's ``'valid_actions'`` is executed from
    the saved pre-step state, its full observation is stored under
    ``'observations'``, and the record is written as one JSON line to
    ``<game>.sas.wt_traj.txt.new``.

    Changes from the previous revision: both files are now closed on every
    exit path (the early return on a mismatch used to leak them), and an
    empty ``'valid_actions'`` list or an empty action group no longer raises
    ``IndexError``.

    Args:
        game: game name used to build the two data-file paths.
        rom_path: path of the game file.

    Returns:
        list of per-step rewards; an empty list when the bindings contain no
        walkthrough; ``None`` when an observation mismatch aborts the run.
    """
    bindings = load_bindings(rom_path)
    scores = []
    if 'walkthrough' not in bindings:
        return scores

    walkthrough = bindings['walkthrough'].split('/')
    seed = bindings['seed']
    env = FrotzEnv(rom_path, seed=seed)

    in_path = '../data/ssa_data/jecc_sup/{}.ssa.wt_traj.txt'.format(game)
    out_path = '../data/ssa_data/jecc_sup/{}.sas.wt_traj.txt.new'.format(game)

    with open(in_path) as filein, open(out_path, 'w') as fileout:
        lines = filein.readlines()
        print(len(lines))

        # Record loaded after the previous step; it is completed and written
        # during the next one.
        wt_ssa_data = None

        for idx, act in enumerate(walkthrough):
            step = idx + 1

            # State before the walkthrough action: the valid actions of the
            # pending record are tried from here.
            state_save = env.get_state()

            observation, reward, _done, _info = env.step(act)
            scores.append(reward)
            new_ob = get_full_obs(env, observation)

            # State after the walkthrough action, restored once the
            # alternative actions have been explored.
            state_prim_save = env.get_state()

            if wt_ssa_data is not None:
                valid_actions = wt_ssa_data['valid_actions']
                if valid_actions and isinstance(valid_actions[0], dict):
                    # Flat list of action dicts: treat it as a single group.
                    valid_actions = [valid_actions]
                    wt_ssa_data['valid_actions'] = valid_actions

                for valid_act_group in valid_actions:
                    if not valid_act_group:
                        continue
                    # Only the first action of each group is executed.
                    valid_act_tuple = valid_act_group[0]
                    env.set_state(state_save)
                    alt_obs, _, _, _ = env.step(valid_act_tuple['a'])
                    valid_act_tuple['observations'] = get_full_obs(env, alt_obs)

                fileout.write(json.dumps(wt_ssa_data) + '\n')
                env.set_state(state_prim_save)

            # The walkthrough has outrun the input file: nothing to load.
            if idx >= len(lines):
                break

            wt_ssa_data = json.loads(lines[idx])
            ssa_obs = '|'.join(wt_ssa_data['observations'].split('|')[0:3])

            if new_ob != ssa_obs:
                print('step:', step)
                print('ERROR: different obs')
                print(new_ob)
                print(ssa_obs)
                return None

    return scores

# games = ['zork1', 'zork3', 'enchanter', 'zork2', 'wishbringer', 'sorcerer']
games = ['wishbringer']

print('#number of games: {}'.format(len(games)))

# Map every game to a ROM file whose name starts with "<game>.z".  When
# several files match, the one listed last by os.listdir wins.
roms = os.listdir('../roms/jericho-game-suite/')
game2rom = {}
logs = []
for game in games:
    for rom in roms:
        if rom.startswith(game + '.z'):
            game2rom[game] = rom
            logs.append('find {} for {}'.format(rom, game))
    if game not in game2rom:
        print('cannot find rom for {}'.format(game))

print('#number of roms founds: {}'.format(len(logs)))

for game in games:
    if game not in game2rom:
        # Already reported above; skip it instead of failing with KeyError.
        continue
    print('working on {}'.format(game))
    game_rom_path = "../roms/jericho-game-suite/{}".format(game2rom[game])
    step_scores = generate_sas_data_new(game, game_rom_path)
    print(np.sum(np.array(step_scores)))
#     break
#     scores = step_scores[:100]
#     scores = np.array(scores)


In [None]:
# NOTE(review): in the visible source `lines` is only assigned as a local
# variable inside generate_sas_data / generate_sas_data_new, so this cell
# raises NameError unless `lines` was defined interactively elsewhere --
# confirm before relying on it.
print(len(lines))

In [None]:

# def generate_output_tuple(observation, info):

#     output_dict = {}
# #     output_dict['actions'] = []
#     output_dict['valid_actions'] = []
#     output_dict['observations'] = observation

# #     for a in info['act']:
# #         new_dict = {}
# #         new_dict['a'] = a.action
# #         new_dict['t'] = a.template_id
# #         new_dict['o'] = a.obj_ids
# #     #     obj_ids = [str(x) for x in a.obj_ids]
# #     #     print(a.action + '\t' + str(a.template_id) + '\t' + ' '.join(obj_ids))
# #         output_dict['actions'].append(new_dict)

#     if isinstance(info['valid_act'][0], str):
#         new_dict = {}
#         new_dict['a'] = info['valid_act'][0]
#         output_dict['valid_actions'].append(new_dict)
        
#     else:
#         for a_list in info['valid_act']:
#             new_list = []
#             for a in a_list:
#                 new_dict = {}
#                 new_dict['a'] = a.action
#                 new_dict['t'] = a.template_id
#                 new_dict['o'] = a.obj_ids
                
#                 new_list.append(new_dict)
#         #     obj_ids = [str(x) for x in a.obj_ids]
#         #     print(a.action + '\t' + str(a.template_id) + '\t' + ' '.join(obj_ids))
#             output_dict['valid_actions'].append(new_list)
        
#     return output_dict

# game_max_scores = []

# games = ['zork1']

# for game in games:
    
#     if os.path.isfile('../data/ssa_data/zork_universe_sup/{}.sas.wt_traj.txt'.format(game)):
#         continue
    
#     print('generating trajectary for game: {}'.format(game))
# #     continue

#     rom_path = "../roms/jericho-game-suite/{}".format(game2rom[game]) # "../roms/jericho-game-suite/zork1.z5"

#     bindings = load_bindings(rom_path)
#     scores = []
#     if 'walkthrough' in bindings:
#         walkthrough = bindings['walkthrough'].split('/')
#         seed = bindings['seed']

#         filein = open('../data/ssa_data/zork_universe_sup/{}.ssa.wt_traj.txt'.format(game))
#         fileout = open('../data/ssa_data/zork_universe_sup/{}.sas.wt_traj.txt'.format(game), 'w')

#         env = JerichoEnv(rom_path, seed=seed)
#     #     env = FrotzEnv(rom_path, seed=seed)
    
#         print(walkthrough)
#         for idx, act in enumerate(walkthrough):

#             observation, reward, done, info = env.env.step(act)
            
#             line = filein.readline()
#             wt_ssa_data = json.loads(line)
#             print(wt_ssa_data)
            
#             break
            
# #             observation, reward, done, info = env.step(act)
#     #         observation, reward, done, info = env.step(act, parallel=False)
#             scores.append(reward)

#     #         if len(info['valid_act']) == 3 and info['valid_act'][0] == 'wait' and info['valid_act'][1] == 'yes' and info['valid_act'][2] == 'no':
    
#             if len(info['valid_act']) == 0:
#                 if idx == len(walkthrough) - 1:
#     #                 print(observation)
#                     break
#                 else:
#                     info['valid_act'] = [walkthrough[idx+1]]

#             output_dict = generate_output_tuple(observation, info)    
#             fileout.write(json.dumps(output_dict) + '\n')
#     #         break

#         print('Total Score', info['score'])
#         game_max_scores.append(info['score'])

#         fileout.close()